ONE - On-device Neural Engine
Loading...
Searching...
No Matches
arm_compute::NEGatherKernelEx Class Reference

#include <NEGatherKernelEx.h>

Collaboration diagram for arm_compute::NEGatherKernelEx:

Public Member Functions

 NEGatherKernelEx ()
 
 NEGatherKernelEx (const NEGatherKernelEx &)=delete
 
NEGatherKernelEx & operator= (const NEGatherKernelEx &)=delete
 
 NEGatherKernelEx (NEGatherKernelEx &&)=default
 
NEGatherKernelEx & operator= (NEGatherKernelEx &&)=default
 
 ~NEGatherKernelEx ()=default
 
const char * name () const override
 
void configure (const ITensor *input, const ITensor *indices, ITensor *output, int axis=0)
 
void run (const Window &window, const ThreadInfo &info) override
 

Static Public Member Functions

static Status validate (const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
 

Detailed Description

Kernel to perform the gather operation on NEON.

Definition at line 52 of file NEGatherKernelEx.h.

Constructor & Destructor Documentation

◆ NEGatherKernelEx() [1/3]

arm_compute::NEGatherKernelEx::NEGatherKernelEx ( )

Default constructor.

Definition at line 76 of file NEGatherKernelEx.cpp.

77 : _input{}, _indices{}, _axis{}, _indices_rank{}, _output{}, _func{}
78{
79}

◆ NEGatherKernelEx() [2/3]

arm_compute::NEGatherKernelEx::NEGatherKernelEx ( const NEGatherKernelEx & )
delete

Prevent instances of this class from being copied (As this class contains pointers).

◆ NEGatherKernelEx() [3/3]

arm_compute::NEGatherKernelEx::NEGatherKernelEx ( NEGatherKernelEx &&  )
default

Allow instances of this class to be moved.

◆ ~NEGatherKernelEx()

arm_compute::NEGatherKernelEx::~NEGatherKernelEx ( )
default

Default destructor.

Member Function Documentation

◆ configure()

void arm_compute::NEGatherKernelEx::configure ( const ITensor *  input,
const ITensor *  indices,
ITensor *  output,
int  axis = 0 
)

Initialise the kernel's inputs and outputs

Parameters
[in] input: Source tensor. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
[in] indices: Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[axis])
[out] output: Destination tensor. Data type supported: Same as input
[in] axis: (Optional) The axis in input to gather indices from. Negative values wrap around. Defaults to 0

Definition at line 167 of file NEGatherKernelEx.cpp.

169{
170 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, indices);
171 ARM_COMPUTE_ERROR_ON(indices->info()->num_dimensions() > 3);
172 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32);
173 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
174 input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
175 DataType::U32, DataType::S32, DataType::F16, DataType::F32);
176
177 _input = input;
178 _indices = indices;
179 _output = output;
180 _axis = axis;
181 _indices_rank = indices->info()->num_dimensions();
182
183 if (_axis < 0)
184 {
185 _axis += input->info()->num_dimensions();
186 }
187 ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions()));
188
189 if (0 == _axis)
190 {
191 switch (_indices->info()->data_type())
192 {
193 case DataType::U32:
194 _func = &NEGatherKernelEx::gather_0_axis<uint32_t>;
195 break;
196 case DataType::S32:
197 _func = &NEGatherKernelEx::gather_0_axis<int32_t>;
198 break;
199 default:
200 ARM_COMPUTE_ERROR("Not supported");
201 break;
202 }
203 }
204 else
205 {
206 switch (_indices->info()->data_type())
207 {
208 case DataType::U32:
209 _func = &NEGatherKernelEx::gather_n_axis<uint32_t>;
210 break;
211 case DataType::S32:
212 _func = &NEGatherKernelEx::gather_n_axis<int32_t>;
213 break;
214 default:
215 ARM_COMPUTE_ERROR("Not supported");
216 break;
217 }
218 }
219 // Output auto initialization if not yet initialized
220 TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape_ex(
221 input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
222 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
223
224 // Create window
225 Window win = calculate_max_window(*output->info(), Steps());
226 output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
227
228 INEKernel::configure(win);
229}
const luci_interpreter::RuntimeShape output_shape
::nncc::core::ADT::tensor::Shape TensorShape
Definition TensorShape.h:25
TensorShape compute_gather_shape_ex(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)

References arm_compute::misc::shape_calculator::compute_gather_shape_ex(), and output_shape.

◆ name()

const char * arm_compute::NEGatherKernelEx::name ( ) const
inline override

Name of the kernel

Returns
Kernel name

Definition at line 72 of file NEGatherKernelEx.h.

72{ return "NEGatherKernelEx"; }

◆ operator=() [1/2]

NEGatherKernelEx & arm_compute::NEGatherKernelEx::operator= ( const NEGatherKernelEx & )
delete

Prevent instances of this class from being copied (As this class contains pointers).

◆ operator=() [2/2]

NEGatherKernelEx & arm_compute::NEGatherKernelEx::operator= ( NEGatherKernelEx &&  )
default

Allow instances of this class to be moved.

◆ run()

void arm_compute::NEGatherKernelEx::run ( const Window &  window,
const ThreadInfo &  info 
)
override

Definition at line 264 of file NEGatherKernelEx.cpp.

265{
266 ARM_COMPUTE_UNUSED(info);
267 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
268 ARM_COMPUTE_ERROR_ON(_func == nullptr);
269
270 (this->*_func)(window, info);
271}
volatile const char info[]

References info.

Referenced by package.infer.session::inference().

◆ validate()

Status arm_compute::NEGatherKernelEx::validate ( const ITensorInfo *  input,
const ITensorInfo *  indices,
const ITensorInfo *  output,
int  axis 
)
static

Static function to check if given info will lead to a valid configuration of NEGatherKernelEx

Parameters
[in] input: Source tensor info. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
[in] indices: Indices tensor info. Supported tensor rank: up to 3. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[axis])
[in] output: Destination tensor info. Data type supported: Same as input
[in] axis: (Optional) The axis in input to gather indices from. Negative values wrap around. Defaults to 0
Returns
a status

Definition at line 231 of file NEGatherKernelEx.cpp.

233{
234 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
235 ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 3);
236 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
237 ARM_COMPUTE_ERROR_ON(input->num_dimensions() + indices->num_dimensions() - 1 > 4);
238
239 if (axis < 0)
240 {
241 axis += input->num_dimensions();
242 }
243
244 ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions()));
245 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
246 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
247 input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
248 DataType::U32, DataType::S32, DataType::F16, DataType::F32);
249
250 if (output->total_size() != 0)
251 {
252 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
253 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
254 TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape_ex(
255 input->tensor_shape(), indices->tensor_shape(), axis);
256 ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size());
257 }
258
259 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32);
260
261 return Status{};
262}

References arm_compute::misc::shape_calculator::compute_gather_shape_ex(), and output_shape.

Referenced by arm_compute::NEGatherEx::validate().


The documentation for this class was generated from the following files: