85 const circle::Tensor *input;
86 const circle::Tensor *position;
87 const circle::Tensor *output;
90 uint8_t *position_data;
93 const circle::GatherOptions *options;
101 input = runtime_kernel.
inputs[inputTensorIdx];
102 position = runtime_kernel.
inputs[positionsTensorIdx];
103 output = runtime_kernel.
outputs[outputTensorIdx];
104 assert(input !=
nullptr);
105 assert(position !=
nullptr);
106 assert(output !=
nullptr);
108 status = runtime_kernel.
getDataFromStorage(op_index, runtime_storage, runtime_context);
112 input_data = runtime_kernel.
inputs_data[inputTensorIdx];
113 position_data = runtime_kernel.
inputs_data[positionsTensorIdx];
114 output_data = runtime_kernel.
outputs_data[outputTensorIdx];
115 assert(input_data !=
nullptr);
116 assert(position_data !=
nullptr);
117 assert(output_data !=
nullptr);
119 options = runtime_kernel.
first_operator->builtin_options_as_GatherOptions();
128 int axis = options->axis();
131 axis += input_dims_size;
134 int batch_dims = options->batch_dims();
140 batch_dims += coords_dims_size;
143 const int axis_size = input_shape.
dims(axis);
146 for (
int i = 0; i < batch_dims; ++i)
148 batch_size *= input_shape.
dims(i);
151 for (
int i = batch_dims; i < axis; ++i)
153 outer_size *= input_shape.
dims(i);
156 for (
int i = axis + 1; i < input_dims_size; ++i)
158 inner_size *= input_shape.
dims(i);
161 for (
int i = batch_dims; i < coords_dims_size; ++i)
163 coord_size *= position_shape.
dims(i);
166 switch (input->type())
169 case circle::TensorType_FLOAT32:
171 status = gather<float, int32_t>(utils::castInputData<float>(input_data),
172 utils::castInputData<int32_t>(position_data),
173 utils::castOutputData<float>(output_data), axis_size,
174 batch_size, outer_size, inner_size, coord_size);
179 case circle::TensorType_INT8:
181 status = gather<int8_t, int32_t>(utils::castInputData<int8_t>(input_data),
182 utils::castInputData<int32_t>(position_data),
183 utils::castOutputData<int8_t>(output_data), axis_size,
184 batch_size, outer_size, inner_size, coord_size);
188 case circle::TensorType_INT32:
190 status = gather<int32_t, int32_t>(utils::castInputData<int32_t>(input_data),
191 utils::castInputData<int32_t>(position_data),
192 utils::castOutputData<int32_t>(output_data), axis_size,
193 batch_size, outer_size, inner_size, coord_size);
199 assert(
false &&
"Unsupported type.");