ONE - On-device Neural Engine
Loading...
Searching...
No Matches
GatherLayer.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "GatherLayer.h"
18
19#include "GGMLHelper.h"
20#include "OperationUtils.h"
21#include "../KernelGenerator.h"
22#include "../Validator.h"
23
namespace onert::backend::ggml
{

27void Validator::visit(const ir::operation::Gather &node)
28{
29 using ir::operation::Gather;
30
31 const auto input_index{node.getInputs().at(Gather::Input::INPUT)};
32 const auto input_node = &_graph.operands().at(input_index);
33
34 _supported = false;
35
36 if (input_node->typeInfo().type() != ir::DataType::QUANT_GGML_Q4_0)
37 return;
38
39 _supported = true;
40}
41
42void KernelGenerator::visit(const ir::operation::Gather &node)
43{
44 const auto output_index{node.getOutputs().at(0)};
45 const auto input_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
46 const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
47
48 auto output_tensor = _tensor_reg->getPortableTensor(output_index);
49 auto input_tensor = _tensor_reg->getPortableTensor(input_index);
50 auto indices_tensor = _tensor_reg->getPortableTensor(indices_index);
51
52 const auto rank = _ctx.at(input_index).shape().rank();
53 const auto axis = ops::getAxis(rank, node.param().axis);
54
55 auto fn = std::make_unique<ops::GatherLayer>();
56
57 fn->configure(input_tensor, indices_tensor, output_tensor, axis, _external_context.get());
58
59 _return_fn = std::move(fn);
60}
61
62} // namespace onert::backend::ggml
63
namespace onert::backend::ggml::ops
{

67void GatherLayer::configure(const IPortableTensor *input, const IPortableTensor *indices,
68 IPortableTensor *output, int32_t axis, ExternalContext *ctx)
69{
70 _input = input;
71 _indices = indices;
72 _axis = axis;
73 _output = output;
74 _ctx = ctx;
75}
76
// Executes Gather on a GGML block-quantized input by lowering it to a
// single-node GGML graph (GGML_OP_GET_ROWS) and running ggml_graph_compute.
// Throws std::runtime_error when the supporting conditions below are not met.
void GatherLayer::runByGGMLQuantInputType()
{
  // Supporting condition
  // Input: rank 2
  // Indice: rank < 4 or rank 4 with dim(0) = 1, INT32
  // Axis: 0
  if (_input->getShape().rank() != 2)
    throw std::runtime_error("Gather: block quantized input tensor must be rank 2");

  // Reject rank >= 5, and rank 4 unless its outermost dimension is 1;
  // rank 1..3 indices pass through unconditionally.
  if (_indices->getShape().rank() >= 4 &&
      (_indices->getShape().rank() != 4 || _indices->getShape().dim(0) != 1))
    throw std::runtime_error("Gather: invalid indices tensor shape");

  if (_indices->data_type() != ir::DataType::INT32)
    throw std::runtime_error("Gather: indices tensor must be int32 type");

  if (_axis != 0)
    throw std::runtime_error("Gather: axis must be 0");

  // convert tensor
  // Wrap the portable tensors as ggml_tensor descriptors (see GGMLHelper).
  // All three live on this stack frame and are referenced by address below,
  // so they must stay alive until ggml_graph_compute returns.
  auto input = getGGMLTensor(_input);
  auto indices = getGGMLTensor(_indices);
  auto output = getGGMLTensor(_output);
  {
    // GET_ROWS gathers rows of src[0] selected by the index values in src[1],
    // which matches Gather with axis == 0 on a rank-2 input.
    output.op = GGML_OP_GET_ROWS;
    output.src[0] = &input;
    output.src[1] = &indices;
  }
  auto *nodes = &output;

  // create graph
  // Hand-built one-node computation graph; zero-initialize so every field
  // the compute path reads has a defined value.
  struct ggml_cgraph graph;
  {
    memset(&graph, 0, sizeof(graph));
    graph.n_nodes = 1;
    graph.nodes = &nodes;
  }

  // get cplan
  // Ask GGML how much scratch memory the plan needs and supply it from a
  // temporary buffer sized to cplan.work_size.
  auto cplan = ggml_graph_plan(&graph, _ctx->maxNumThreads());
  std::vector<uint8_t> buf(cplan.work_size);
  cplan.work_data = buf.data();

  // compute
  ggml_graph_compute(&graph, &cplan);
}
123
125{
126 switch (_input->data_type())
127 {
128 case ir::DataType::QUANT_GGML_Q4_0:
129 runByGGMLQuantInputType();
130 break;
131 default:
132 throw std::runtime_error("Gather: unsupported input data type");
133 }
134}
135
136} // namespace onert::backend::ggml::ops
A tensor class that is portable for other backends.
ir::DataType data_type() const override final
ir::Shape getShape() const override final
Get ir::Shape of tensor.
std::unique_ptr< exec::IFunction > _return_fn
void configure(const IPortableTensor *input, const IPortableTensor *indices, IPortableTensor *output, int32_t axis, ExternalContext *ctx)
const Operands & operands() const override
Definition Graph.h:103
const Object & at(const Index &index) const
Get the object that is associated with the given index.
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.
struct ggml_tensor getGGMLTensor(const IPortableTensor *tensor)
Definition GGMLHelper.cc:41