ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Gather.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "OMStatus.h"
18
19#include "core/OMUtils.h"
20#include "core/OMKernelData.h"
21
23#include "execute/OMUtils.h"
25
26using namespace onert_micro;
27using namespace onert_micro::core;
28using namespace onert_micro::execute;
29
30namespace
31{
32
// Slots in the kernel's input arrays (`inputs` / `inputs_data`).
constexpr uint32_t inputTensorIdx{0};     // tensor the slices are gathered from
constexpr uint32_t positionsTensorIdx{1}; // tensor holding the gather indices

// Slot in the kernel's output arrays (`outputs` / `outputs_data`).
constexpr uint32_t outputTensorIdx{0};
37
38template <typename InputT, typename CoordsT = int32_t>
39OMStatus gather(const InputT *input_data, const CoordsT *coords_data, InputT *output_data,
40 int32_t axis_size, int32_t batch_size, int32_t outer_size, int32_t inner_size,
41 int32_t coord_size)
42{
43
44 for (int batch = 0; batch < batch_size; ++batch)
45 {
46 for (int outer = 0; outer < outer_size; ++outer)
47 {
48 for (int coord = 0; coord < coord_size; ++coord)
49 {
50 auto x = coords_data[batch * coord_size + coord];
51
52 // Bounds check: index must be in range [0, axis_size)
53 if (x < 0 || x >= axis_size)
54 {
55 return IndexError;
56 }
57
58 std::memcpy(
59 output_data + (((batch * outer_size) + outer) * coord_size + coord) * inner_size,
60 input_data +
61 (((batch * outer_size) + outer) * axis_size + coords_data[batch * coord_size + coord]) *
62 inner_size,
63 sizeof(InputT) * inner_size);
64 }
65 }
66 }
67
68 return Ok;
69}
70
71} // namespace
72
73// NOTE: doesn't currently support dynamic shapes
74namespace onert_micro
75{
76namespace execute
77{
78
80{
81 core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
82 core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
83 uint16_t op_index = execute_args.kernel_index;
84
85 const circle::Tensor *input;
86 const circle::Tensor *position;
87 const circle::Tensor *output;
88
89 uint8_t *input_data;
90 uint8_t *position_data;
91 uint8_t *output_data;
92
93 const circle::GatherOptions *options;
94 // Read kernel
95 {
96 execute::OMRuntimeKernel runtime_kernel;
97 OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
98 if (status != Ok)
99 return status;
100
101 input = runtime_kernel.inputs[inputTensorIdx];
102 position = runtime_kernel.inputs[positionsTensorIdx];
103 output = runtime_kernel.outputs[outputTensorIdx];
104 assert(input != nullptr);
105 assert(position != nullptr);
106 assert(output != nullptr);
107
108 status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
109 if (status != Ok)
110 return status;
111
112 input_data = runtime_kernel.inputs_data[inputTensorIdx];
113 position_data = runtime_kernel.inputs_data[positionsTensorIdx];
114 output_data = runtime_kernel.outputs_data[outputTensorIdx];
115 assert(input_data != nullptr);
116 assert(position_data != nullptr);
117 assert(output_data != nullptr);
118
119 options = runtime_kernel.first_operator->builtin_options_as_GatherOptions();
120 }
121
122 OMStatus status = Ok;
123
124 OMRuntimeShape position_shape(position);
125 OMRuntimeShape input_shape(input);
126
127 const int input_dims_size = input_shape.dimensionsCount();
128 int axis = options->axis();
129 if (axis < 0)
130 {
131 axis += input_dims_size;
132 }
133
134 int batch_dims = options->batch_dims();
135 // batch_dims should be in range: [-rank(coords), rank(coords)].
136 // Negative batch_dims is added with rank of coords.
137 const int coords_dims_size = position_shape.dimensionsCount();
138 if (batch_dims < 0)
139 {
140 batch_dims += coords_dims_size;
141 }
142
143 const int axis_size = input_shape.dims(axis);
144
145 int batch_size = 1;
146 for (int i = 0; i < batch_dims; ++i)
147 {
148 batch_size *= input_shape.dims(i);
149 }
150 int outer_size = 1;
151 for (int i = batch_dims; i < axis; ++i)
152 {
153 outer_size *= input_shape.dims(i);
154 }
155 int inner_size = 1;
156 for (int i = axis + 1; i < input_dims_size; ++i)
157 {
158 inner_size *= input_shape.dims(i);
159 }
160 int coord_size = 1;
161 for (int i = batch_dims; i < coords_dims_size; ++i)
162 {
163 coord_size *= position_shape.dims(i);
164 }
165
166 switch (input->type())
167 {
168#ifndef DIS_FLOAT
169 case circle::TensorType_FLOAT32:
170 {
171 status = gather<float, int32_t>(utils::castInputData<float>(input_data),
172 utils::castInputData<int32_t>(position_data),
173 utils::castOutputData<float>(output_data), axis_size,
174 batch_size, outer_size, inner_size, coord_size);
175 }
176 break;
177#endif // DIS_FLOAT
178#ifndef DIS_QUANT
179 case circle::TensorType_INT8:
180 {
181 status = gather<int8_t, int32_t>(utils::castInputData<int8_t>(input_data),
182 utils::castInputData<int32_t>(position_data),
183 utils::castOutputData<int8_t>(output_data), axis_size,
184 batch_size, outer_size, inner_size, coord_size);
185 }
186 break;
187#endif // DIS_QUANT
188 case circle::TensorType_INT32:
189 {
190 status = gather<int32_t, int32_t>(utils::castInputData<int32_t>(input_data),
191 utils::castInputData<int32_t>(position_data),
192 utils::castOutputData<int32_t>(output_data), axis_size,
193 batch_size, outer_size, inner_size, coord_size);
194 }
195 break;
196 default:
197 {
198 status = UnsupportedActivation;
199 assert(false && "Unsupported type.");
200 }
201 }
202
203 return status;
204}
205
206} // namespace execute
207} // namespace onert_micro
size_t dimensionsCount() const noexcept
uint8_t * outputs_data[maxOutputSize]
const circle::Operator * first_operator
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
constexpr uint32_t outputTensorIdx
OMStatus execute_kernel_CircleGather(const OMExecuteArgs &execute_args)
Definition Gather.cpp:79
@ UnsupportedActivation
Definition OMStatus.h:28
core::OMRuntimeContext & runtime_context
core::OMRuntimeStorage & runtime_storage