// Slots of the Gather operator's input tensors inside the runtime kernel's
// `inputs` / `inputs_data` arrays.
constexpr uint32_t inputTensorIdx{0};     // presumably the data tensor being gathered from
constexpr uint32_t positionsTensorIdx{1}; // tensor holding the gather coordinates
/**
 * @brief Reference Gather: copies the slices of @p input_data selected by @p coords_data
 *        into @p output_data.
 *
 * Index layout (from the offsets computed below):
 *  - input  is addressed as [batch_size][outer_size][axis_size][inner_size]
 *  - coords is addressed as [batch_size][coord_size]
 *  - output is addressed as [batch_size][outer_size][coord_size][inner_size]
 * For every (batch, outer, coord), the inner_size-element slice
 * input[batch][outer][coords[batch][coord]][:] is copied to output[batch][outer][coord][:].
 *
 * A coordinate outside [0, axis_size) would address memory outside @p input_data.
 * Instead of reading out of bounds, the corresponding output slice is zero-filled.
 *
 * @param input_data   source tensor data
 * @param coords_data  gather coordinates along the gathered axis
 * @param output_data  destination buffer (batch_size * outer_size * coord_size * inner_size elems)
 * @param axis_size    extent of the gathered axis in the input
 * @param batch_size   product of the leading batch dimensions
 * @param outer_size   product of the dimensions between the batch dims and the axis
 * @param inner_size   product of the dimensions after the axis
 * @param coord_size   number of coordinates per batch
 */
template <typename InputT, typename CoordsT = int32_t>
void gather(const InputT *input_data, const CoordsT *coords_data, InputT *output_data,
            int32_t axis_size, int32_t batch_size, int32_t outer_size, int32_t inner_size,
            int32_t coord_size)
{
  for (int batch = 0; batch < batch_size; ++batch)
  {
    // Coordinates are per batch: row `batch` of the [batch_size][coord_size] view.
    const CoordsT *batch_coords = coords_data + batch * coord_size;

    for (int outer = 0; outer < outer_size; ++outer)
    {
      // Hoisted base offsets for this (batch, outer) pair.
      const int slab = batch * outer_size + outer;
      const InputT *input_slab = input_data + slab * axis_size * inner_size;
      InputT *output_slab = output_data + slab * coord_size * inner_size;

      for (int coord = 0; coord < coord_size; ++coord)
      {
        const CoordsT index = batch_coords[coord];
        InputT *dst = output_slab + coord * inner_size;

        if (index < 0 || index >= axis_size)
        {
          // Out-of-range coordinate: never read past the input; emit zeros instead.
          for (int32_t i = 0; i < inner_size; ++i)
            dst[i] = InputT(0);
          continue;
        }

        std::memcpy(dst, input_slab + index * inner_size, sizeof(InputT) * inner_size);
      }
    }
  }
}
71 const circle::Tensor *
input;
72 const circle::Tensor *position;
73 const circle::Tensor *
output;
76 uint8_t *position_data;
79 const circle::GatherOptions *
options;
88 position = runtime_kernel.
inputs[positionsTensorIdx];
90 assert(input !=
nullptr);
91 assert(position !=
nullptr);
92 assert(output !=
nullptr);
94 status = runtime_kernel.
getDataFromStorage(op_index, runtime_storage, runtime_context);
99 position_data = runtime_kernel.
inputs_data[positionsTensorIdx];
101 assert(input_data !=
nullptr);
102 assert(position_data !=
nullptr);
103 assert(output_data !=
nullptr);
113 const int input_dims_size = input_shape.dimensionsCount();
117 axis += input_dims_size;
120 int batch_dims =
options->batch_dims();
123 const int coords_dims_size = position_shape.dimensionsCount();
126 batch_dims += coords_dims_size;
129 const int axis_size = input_shape.dims(axis);
132 for (
int i = 0; i < batch_dims; ++i)
134 batch_size *= input_shape.dims(i);
137 for (
int i = batch_dims; i < axis; ++i)
139 outer_size *= input_shape.dims(i);
142 for (
int i = axis + 1; i < input_dims_size; ++i)
144 inner_size *= input_shape.dims(i);
147 for (
int i = batch_dims; i < coords_dims_size; ++i)
149 coord_size *= position_shape.dims(i);
152 switch (
input->type())
155 case circle::TensorType_FLOAT32:
157 gather<float, int32_t>(utils::castInputData<float>(input_data),
158 utils::castInputData<int32_t>(position_data),
159 utils::castOutputData<float>(output_data), axis_size, batch_size,
160 outer_size, inner_size, coord_size);
165 case circle::TensorType_INT8:
167 gather<int8_t, int32_t>(utils::castInputData<int8_t>(input_data),
168 utils::castInputData<int32_t>(position_data),
169 utils::castOutputData<int8_t>(output_data), axis_size, batch_size,
170 outer_size, inner_size, coord_size);
174 case circle::TensorType_INT32:
176 gather<int32_t, int32_t>(utils::castInputData<int32_t>(input_data),
177 utils::castInputData<int32_t>(position_data),
178 utils::castOutputData<int32_t>(output_data), axis_size, batch_size,
179 outer_size, inner_size, coord_size);
185 assert(
false &&
"Unsupported type.");
uint8_t * outputs_data[maxOutputSize]
const circle::Operator * first_operator
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
uint8_t * inputs_data[maxInputSize]
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
constexpr uint32_t outputTensorIdx
core::OMRuntimeContext & runtime_context
core::OMRuntimeStorage & runtime_storage