19#include "kernels/Utils.h"
29template <
typename InputT,
typename CoordsT =
int32_t>
30void gather(
const circle::GatherOptions *options, kernels::TISOKernel *kernel)
32 kernels::TISOData tiso_data = kernel->readData();
34 const InputT *
input_data = kernels::getTensorData<InputT>(tiso_data.input1_data);
35 const CoordsT *coords_data = kernels::getTensorData<CoordsT>(tiso_data.input2_data);
36 InputT *
output_data = kernels::getTensorData<InputT>(tiso_data.output_data);
38 const circle::Tensor *
input = kernel->input1();
39 const circle::Tensor *
coords = kernel->input2();
41 const int input_dims_size = Tensor::num_dims(input);
45 axis += input_dims_size;
48 int batch_dims =
options->batch_dims();
51 const int coords_dims_size = Tensor::num_dims(
coords);
54 batch_dims += coords_dims_size;
60 for (
int i = 0; i < batch_dims; ++i)
65 for (
int i = batch_dims; i < axis; ++i)
70 for (
int i = axis + 1; i < input_dims_size; ++i)
75 for (
int i = batch_dims; i < coords_dims_size; ++i)
80 for (
int batch = 0; batch < batch_size; ++batch)
82 for (
int outer = 0; outer < outer_size; ++outer)
84 for (
int coord = 0; coord < coord_size; ++coord)
86 auto x = coords_data[coord];
88 output_data + (((batch * outer_size) + outer) * coord_size + coord) * inner_size,
90 (((batch * outer_size) + outer) * axis_size + coords_data[batch * coord_size + coord]) *
92 sizeof(InputT) * inner_size);
104 const auto *options = cur_op->builtin_options_as_GatherOptions();
108 Tensor::element_type(kernel.
input1()) == DataType::S8 or
109 Tensor::element_type(kernel.
input1()) == DataType::S32);
111 int32_t axis = options->axis();
112 int32_t num_dims = Tensor::num_dims(kernel.
input1());
120 int32_t batch_dims = options->batch_dims();
121 int32_t coords_num_dims = Tensor::num_dims(kernel.
input2());
126 batch_dims += coords_num_dims;
131 for (
int i = 0; i < batch_dims; ++i)
141 const auto *options = cur_op->builtin_options_as_GatherOptions();
143 switch (Tensor::element_type(kernel.
input1()))
146 case DataType::FLOAT32:
147 return gather<float, int32_t>(options, &kernel);
151 return gather<int8_t, int32_t>(options, &kernel);
154 return gather<int32_t, int32_t>(options, &kernel);
156 assert(
false &&
"Unsupported type");
Array< CornerBox > coords
const circle::Tensor * input2() const
const circle::Tensor * input1() const
#define LUCI_INTERPRETER_CHECK(cond)
void configure_kernel_CircleGather(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void execute_kernel_CircleGather(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
const loco::Dimension & dim(uint32_t axis) const