76 const circle::Tensor *input;
77 const circle::Tensor *position;
78 const circle::Tensor *output;
81 uint8_t *position_data;
84 const circle::GatherOptions *options;
92 input = runtime_kernel.
inputs[inputTensorIdx];
93 position = runtime_kernel.
inputs[positionsTensorIdx];
94 output = runtime_kernel.
outputs[outputTensorIdx];
95 assert(input !=
nullptr);
96 assert(position !=
nullptr);
97 assert(output !=
nullptr);
99 status = runtime_kernel.
getDataFromStorage(op_index, runtime_storage, runtime_context);
103 input_data = runtime_kernel.
inputs_data[inputTensorIdx];
104 position_data = runtime_kernel.
inputs_data[positionsTensorIdx];
105 output_data = runtime_kernel.
outputs_data[outputTensorIdx];
106 assert(input_data !=
nullptr);
107 assert(position_data !=
nullptr);
108 assert(output_data !=
nullptr);
110 options = runtime_kernel.
first_operator->builtin_options_as_GatherOptions();
119 int axis = options->axis();
122 axis += input_dims_size;
125 int batch_dims = options->batch_dims();
131 batch_dims += coords_dims_size;
134 const int axis_size = input_shape.
dims(axis);
137 for (
int i = 0; i < batch_dims; ++i)
139 batch_size *= input_shape.
dims(i);
142 for (
int i = batch_dims; i < axis; ++i)
144 outer_size *= input_shape.
dims(i);
147 for (
int i = axis + 1; i < input_dims_size; ++i)
149 inner_size *= input_shape.
dims(i);
152 for (
int i = batch_dims; i < coords_dims_size; ++i)
154 coord_size *= position_shape.
dims(i);
157 switch (input->type())
160 case circle::TensorType_FLOAT32:
162 gather<float, int32_t>(utils::castInputData<float>(input_data),
163 utils::castInputData<int32_t>(position_data),
164 utils::castOutputData<float>(output_data), axis_size, batch_size,
165 outer_size, inner_size, coord_size);
170 case circle::TensorType_INT8:
172 gather<int8_t, int32_t>(utils::castInputData<int8_t>(input_data),
173 utils::castInputData<int32_t>(position_data),
174 utils::castOutputData<int8_t>(output_data), axis_size, batch_size,
175 outer_size, inner_size, coord_size);
179 case circle::TensorType_INT32:
181 gather<int32_t, int32_t>(utils::castInputData<int32_t>(input_data),
182 utils::castInputData<int32_t>(position_data),
183 utils::castOutputData<int32_t>(output_data), axis_size, batch_size,
184 outer_size, inner_size, coord_size);
190 assert(
false &&
"Unsupported type.");