169{
170 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, indices);
171 ARM_COMPUTE_ERROR_ON(indices->info()->num_dimensions() > 3);
172 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32);
173 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
174 input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
175 DataType::U32, DataType::S32, DataType::F16, DataType::F32);
176
178 _indices = indices;
180 _axis = axis;
181 _indices_rank = indices->info()->num_dimensions();
182
183 if (_axis < 0)
184 {
185 _axis +=
input->info()->num_dimensions();
186 }
187 ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >=
static_cast<int32_t
>(
input->info()->num_dimensions()));
188
189 if (0 == _axis)
190 {
191 switch (_indices->info()->data_type())
192 {
193 case DataType::U32:
194 _func = &NEGatherKernelEx::gather_0_axis<uint32_t>;
195 break;
196 case DataType::S32:
197 _func = &NEGatherKernelEx::gather_0_axis<int32_t>;
198 break;
199 default:
200 ARM_COMPUTE_ERROR("Not supported");
201 break;
202 }
203 }
204 else
205 {
206 switch (_indices->info()->data_type())
207 {
208 case DataType::U32:
209 _func = &NEGatherKernelEx::gather_n_axis<uint32_t>;
210 break;
211 case DataType::S32:
212 _func = &NEGatherKernelEx::gather_n_axis<int32_t>;
213 break;
214 default:
215 ARM_COMPUTE_ERROR("Not supported");
216 break;
217 }
218 }
219
221 input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
223
224
225 Window win = calculate_max_window(*
output->info(), Steps());
226 output->info()->set_valid_region(ValidRegion(Coordinates(),
output->info()->tensor_shape()));
227
228 INEKernel::configure(win);
229}
const luci_interpreter::RuntimeShape output_shape
::nncc::core::ADT::tensor::Shape TensorShape
TensorShape compute_gather_shape_ex(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)