ONE - On-device Neural Engine
arm_compute::NEOneHotKernel Class Reference

#include <NEOneHotKernel.h>

Public Member Functions

 NEOneHotKernel ()
 
 NEOneHotKernel (const NEOneHotKernel &)=delete
 
NEOneHotKernel & operator= (const NEOneHotKernel &)=delete
 
 NEOneHotKernel (NEOneHotKernel &&)=default
 
NEOneHotKernel & operator= (NEOneHotKernel &&)=default
 
 ~NEOneHotKernel ()=default
 
const char * name () const override
 
void configure (const ITensor *indices, const ITensor *depth, const ITensor *on_value, const ITensor *off_value, ITensor *output, int axis=-1)
 
void run (const Window &window, const ThreadInfo &info) override
 

Static Public Member Functions

static Status validate (const ITensorInfo *indices, const ITensorInfo *depth, const ITensorInfo *on_value, const ITensorInfo *off_value, const ITensorInfo *output, int axis=-1)
 

Detailed Description

Kernel to perform the one-hot operation on NEON.

Definition at line 49 of file NEOneHotKernel.h.

Constructor & Destructor Documentation

◆ NEOneHotKernel() [1/3]

arm_compute::NEOneHotKernel::NEOneHotKernel ( )

Default constructor.

Definition at line 107 of file NEOneHotKernel.cpp.

  : _indices{nullptr}, _depth{nullptr}, _on_value{nullptr}, _off_value{nullptr}, _axis{-1},
    _output{nullptr}, _func{}
{
}

◆ NEOneHotKernel() [2/3]

arm_compute::NEOneHotKernel::NEOneHotKernel ( const NEOneHotKernel & )
delete

Prevent instances of this class from being copied (as this class contains pointers).

◆ NEOneHotKernel() [3/3]

arm_compute::NEOneHotKernel::NEOneHotKernel ( NEOneHotKernel &&  )
default

Allow instances of this class to be moved.

◆ ~NEOneHotKernel()

arm_compute::NEOneHotKernel::~NEOneHotKernel ( )
default

Default destructor.

Member Function Documentation

◆ configure()

void arm_compute::NEOneHotKernel::configure ( const ITensor *  indices,
const ITensor *  depth,
const ITensor *  on_value,
const ITensor *  off_value,
ITensor *  output,
int  axis = -1 
)

Initialise the kernel's inputs and outputs

Parameters
[in] indices: Indices tensor. Supported tensor rank: up to 3. Data type supported: U32/S32
[in] depth: The tensor for depth of the one hot dimension. Supported tensor rank: up to 3. Data type supported: U32/S32
[in] on_value: On value tensor. Supported tensor rank: only 1. Data type supported: U8/S8/U16/S16/F16/U32/S32/F32
[in] off_value: Off value tensor. Supported tensor rank: only 1. Data type supported: Same as on_value
[out] output: Destination tensor. Data type supported: Same as on_value
[in] axis: (Optional) The axis to fill. Negative values wrap around. Defaults to -1. The value must be in the range [-indices.rank, indices.rank)
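
The wrap-around rule for negative axis values can be illustrated with a small standalone sketch. The function name `wrap_axis` is hypothetical and only stands in for the `wrap_around` helper that `configure()` calls; it assumes the usual convention that the result is mapped into [0, m).

```cpp
#include <cassert>

// Hypothetical stand-in for the wrap_around helper used by configure().
// Maps x into [0, m), so negative axes count back from the last dimension.
constexpr int wrap_axis(int x, int m)
{
    return x >= 0 ? x % m : (x % m + m) % m;
}

// Small self-check: with 3 dimensions, -1 selects the last one.
inline void check_wrap_axis()
{
    assert(wrap_axis(-1, 3) == 2);
    assert(wrap_axis(0, 3) == 0);
}
```

With three dimensions, an axis of -1 resolves to 2 and an axis of -3 resolves to 0.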

Definition at line 167 of file NEOneHotKernel.cpp.

{
  ARM_COMPUTE_ERROR_ON_NULLPTR(indices, depth, on_value, off_value, output);
  ARM_COMPUTE_ERROR_ON(output->info()->total_size() == 0);
  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(indices->info(), depth->info(), on_value->info(),
                                                off_value->info(), output->info(), axis));
  _indices = indices;
  _depth = depth;
  _on_value = on_value;
  _off_value = off_value;
  _output = output;
  _axis = wrap_around(axis, static_cast<int>(output->info()->num_dimensions()));
  if (0 == _axis)
  {
    switch (_indices->info()->data_type())
    {
      case DataType::U32:
        _func = &NEOneHotKernel::onehot_0_axis<uint32_t>;
        break;
      case DataType::S32:
        _func = &NEOneHotKernel::onehot_0_axis<int32_t>;
        break;
      default:
        ARM_COMPUTE_ERROR("Not supported");
        break;
    }
  }
  else
  {
    switch (_indices->info()->data_type())
    {
      case DataType::U32:
        _func = &NEOneHotKernel::onehot_n_axis<uint32_t>;
        break;
      case DataType::S32:
        _func = &NEOneHotKernel::onehot_n_axis<int32_t>;
        break;
      default:
        ARM_COMPUTE_ERROR("Not supported");
        break;
    }
  }
  // Create window
  Window win = calculate_max_window(*output->info(), Steps());
  output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
  INEKernel::configure(win);
}
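
For readers unfamiliar with the operation itself, the following is a minimal reference sketch of one-hot encoding on plain `std::vector` data. It is not the kernel's implementation: the name `one_hot_reference`, the row-major `[N][depth]` output layout, and the treatment of out-of-range indices (the row stays all `off_value`) are assumptions made for illustration, and the sketch does not model the `axis` parameter.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative reference only (hypothetical helper, not part of the library).
// Expands 1-D indices into a row-major [indices.size()][depth] buffer.
// Row i holds on_value at column indices[i] and off_value everywhere else.
// Assumption: an index outside [0, depth) leaves its row entirely off_value.
template <typename T>
std::vector<T> one_hot_reference(const std::vector<int32_t> &indices, int32_t depth,
                                 T on_value, T off_value)
{
    assert(depth > 0);
    const std::size_t d = static_cast<std::size_t>(depth);
    std::vector<T> out(indices.size() * d, off_value);
    for (std::size_t i = 0; i < indices.size(); ++i)
    {
        const int32_t idx = indices[i];
        if (idx >= 0 && idx < depth)
        {
            out[i * d + static_cast<std::size_t>(idx)] = on_value;
        }
    }
    return out;
}
```

Calling `one_hot_reference<float>({0, 2}, 3, 1.0f, 0.0f)` under these assumptions yields the rows `1 0 0` and `0 0 1`.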

◆ name()

const char * arm_compute::NEOneHotKernel::name ( ) const
inline override

Name of the kernel

Returns
Kernel name

Definition at line 68 of file NEOneHotKernel.h.

{ return "NEOneHotKernel"; }

◆ operator=() [1/2]

NEOneHotKernel & arm_compute::NEOneHotKernel::operator= ( const NEOneHotKernel & )
delete

Prevent instances of this class from being copied (as this class contains pointers).

◆ operator=() [2/2]

NEOneHotKernel & arm_compute::NEOneHotKernel::operator= ( NEOneHotKernel &&  )
default

Allow instances of this class to be moved.

◆ run()

void arm_compute::NEOneHotKernel::run ( const Window &  window,
const ThreadInfo &  info 
)
override

Definition at line 226 of file NEOneHotKernel.cpp.

{
  ARM_COMPUTE_UNUSED(info);
  ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
  ARM_COMPUTE_ERROR_ON(_func == nullptr);
  (this->*_func)(window, info);
}

References info.

Referenced by package.infer.session::inference().
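
The pairing of `configure()` and `run()` relies on a pointer to a templated member function that is chosen once and invoked later. A stripped-down sketch of that pattern follows; `MiniKernel`, `IndexType`, and `process` are hypothetical names used only to show the mechanism, not library types.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>

// Hypothetical miniature of the dispatch pattern; not the library's class.
enum class IndexType { U32, S32 };

class MiniKernel
{
public:
    // Select the template instantiation once, based on the index type.
    void configure(IndexType type)
    {
        switch (type)
        {
            case IndexType::U32:
                _func = &MiniKernel::process<uint32_t>;
                break;
            case IndexType::S32:
                _func = &MiniKernel::process<int32_t>;
                break;
        }
    }

    // Invoke whichever instantiation configure() stored.
    std::string run()
    {
        assert(_func != nullptr); // mirrors the null check in run()
        return (this->*_func)();
    }

private:
    template <typename T> std::string process()
    {
        return std::is_signed<T>::value ? "signed" : "unsigned";
    }

    using Func = std::string (MiniKernel::*)();
    Func _func{nullptr};
};
```

Resolving the type switch at configuration time keeps the per-call path in `run()` down to a single indirect call.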

◆ validate()

Status arm_compute::NEOneHotKernel::validate ( const ITensorInfo *  indices,
const ITensorInfo *  depth,
const ITensorInfo *  on_value,
const ITensorInfo *  off_value,
const ITensorInfo *  output,
int  axis = -1 
)
static

Static function to check if given info will lead to a valid configuration of NEOneHotKernel

Parameters
[in] indices: Indices tensor info. Supported tensor rank: up to 3. Data type supported: U32/S32
[in] depth: The tensor info for depth of the one hot dimension. Supported tensor rank: up to 3. Data type supported: U32/S32
[in] on_value: On value tensor info. Supported tensor rank: only 1. Data type supported: U8/S8/U16/S16/F16/U32/S32/F32
[in] off_value: Off value tensor info. Supported tensor rank: only 1. Data type supported: Same as on_value
[out] output: Destination tensor info. Data type supported: Same as on_value
[in] axis: (Optional) The axis to fill. Negative values wrap around. Defaults to -1. The value must be in the range [-indices.rank, indices.rank)
Returns
a status

Definition at line 217 of file NEOneHotKernel.cpp.

{
  ARM_COMPUTE_RETURN_ON_ERROR(
    validate_arguments(indices, depth, on_value, off_value, output, axis));
  return Status{};
}

Referenced by arm_compute::NEOneHot::validate().
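
The constraints listed in the parameter table can be read as a sequence of checks. The sketch below restates some of them on simplified stand-in types; `TensorDesc`, `DType`, and `check_one_hot_args` are hypothetical, the real function takes `ITensorInfo` pointers and returns a `Status`, and the actual `validate_arguments` may check more or differently.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for tensor metadata; not library types.
enum class DType { U8, S32, U32, F32 };

struct TensorDesc
{
    int   rank;
    DType type;
};

// Returns an empty string when the documented constraints hold,
// otherwise a short description of the first violation found.
inline std::string check_one_hot_args(const TensorDesc &indices, const TensorDesc &on_value,
                                      const TensorDesc &off_value, const TensorDesc &output,
                                      int axis = -1)
{
    if (indices.rank > 3)
        return "indices rank must be at most 3";
    if (indices.type != DType::U32 && indices.type != DType::S32)
        return "indices must be U32 or S32";
    if (on_value.rank != 1 || off_value.rank != 1)
        return "on_value and off_value must have rank 1";
    if (off_value.type != on_value.type)
        return "off_value type must match on_value";
    if (output.type != on_value.type)
        return "output type must match on_value";
    // Axis range as stated in the parameter list: [-indices.rank, indices.rank).
    if (axis < -indices.rank || axis >= indices.rank)
        return "axis out of range";
    return std::string{};
}
```

A caller following this pattern would run the check first and only configure the kernel when no violation is reported.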


The documentation for this class was generated from the following files: