ONE - On-device Neural Engine
arm_compute::CLOneHotKernel Class Reference

#include <CLOneHotKernel.h>


Public Member Functions

 CLOneHotKernel ()
 
 CLOneHotKernel (const CLOneHotKernel &)=delete
 
CLOneHotKernel & operator= (const CLOneHotKernel &)=delete
 
 CLOneHotKernel (CLOneHotKernel &&)=default
 
CLOneHotKernel & operator= (CLOneHotKernel &&)=default
 
 ~CLOneHotKernel ()=default
 
void configure (const ICLTensor *indices, const ICLTensor *on_value, const ICLTensor *off_value, ICLTensor *output, int depth, int axis=-1)
 
void configure (const ICLTensor *indices, const ICLTensor *on_value, ICLTensor *output, int depth, int axis=-1)
 
void run (const Window &window, cl::CommandQueue &queue) override
 

Static Public Member Functions

static Status validate (const ITensorInfo *indices, const ITensorInfo *on_value, const ITensorInfo *off_value, const ITensorInfo *output, int depth, int axis=-1)
 
static Status validate (const ITensorInfo *indices, const ITensorInfo *on_value, const ITensorInfo *output, int depth, int axis=-1)
 

Detailed Description

Interface for the OpenCL kernel that performs one-hot encoding.

Definition at line 48 of file CLOneHotKernel.h.
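To make the operation concrete, here is a minimal CPU sketch of one-hot encoding. It is not the OpenCL kernel: it assumes a 1-D indices tensor, places the one-hot dimension innermost in a row-major buffer, and assumes TensorFlow-style behaviour in which an index outside [0, depth) yields a row of off_value. The function name `one_hot_reference` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for one-hot encoding of a 1-D indices tensor.
// Output layout (assumption): row-major [num_indices][depth], i.e. the
// one-hot dimension is innermost. An index outside [0, depth) leaves its
// whole row at off_value (TensorFlow-style behaviour, also an assumption).
template <typename T>
std::vector<T> one_hot_reference(const std::vector<int32_t> &indices, int depth,
                                 T on_value, T off_value)
{
    const std::size_t d_size = static_cast<std::size_t>(depth);
    // Start with every element set to off_value.
    std::vector<T> output(indices.size() * d_size, off_value);
    for (std::size_t i = 0; i < indices.size(); ++i)
    {
        const int32_t idx = indices[i];
        if (idx >= 0 && idx < depth)
        {
            // Exactly one element per in-range index becomes on_value.
            output[i * d_size + static_cast<std::size_t>(idx)] = on_value;
        }
    }
    return output;
}
```

For example, indices {0, 2, 5} with depth 3, on_value 1 and off_value 0 give the rows [1, 0, 0], [0, 0, 1] and [0, 0, 0].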

Constructor & Destructor Documentation

◆ CLOneHotKernel() [1/3]

arm_compute::CLOneHotKernel::CLOneHotKernel ( )

Default constructor

Definition at line 94 of file CLOneHotKernel.cpp.

    : _indices(nullptr), _on_value(nullptr), _off_value(nullptr), _output(nullptr),
      _is_off_value_memset(false)
{
}

◆ CLOneHotKernel() [2/3]

arm_compute::CLOneHotKernel::CLOneHotKernel ( const CLOneHotKernel & )
delete

Prevent instances of this class from being copied (as this class contains pointers).

◆ CLOneHotKernel() [3/3]

arm_compute::CLOneHotKernel::CLOneHotKernel ( CLOneHotKernel &&  )
default

Allow instances of this class to be moved
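The copy/move policy above is the usual pattern for a kernel that stores non-owning tensor pointers. The sketch below uses a hypothetical `MoveOnlyKernel` with a placeholder `Tensor` type (neither is part of the library) and checks the same policy at compile time with standard type traits.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Placeholder standing in for ICLTensor (hypothetical).
struct Tensor
{
};

// Hypothetical class mirroring CLOneHotKernel's special members:
// it stores a non-owning pointer, copying is deleted, moving is defaulted.
class MoveOnlyKernel
{
public:
    MoveOnlyKernel() = default;
    MoveOnlyKernel(const MoveOnlyKernel &) = delete;
    MoveOnlyKernel &operator=(const MoveOnlyKernel &) = delete;
    MoveOnlyKernel(MoveOnlyKernel &&) = default;
    MoveOnlyKernel &operator=(MoveOnlyKernel &&) = default;
    ~MoveOnlyKernel() = default;

    void configure(const Tensor *input) { _input = input; }
    const Tensor *input() const { return _input; }

private:
    const Tensor *_input{nullptr}; // non-owning, like _indices/_output
};

// Compile-time checks of the copy/move policy.
static_assert(!std::is_copy_constructible<MoveOnlyKernel>::value, "copy must be disabled");
static_assert(!std::is_copy_assignable<MoveOnlyKernel>::value, "copy assignment must be disabled");
static_assert(std::is_move_constructible<MoveOnlyKernel>::value, "move must be allowed");
static_assert(std::is_move_assignable<MoveOnlyKernel>::value, "move assignment must be allowed");
```

A defaulted move of a raw pointer simply copies it, so the moved-from object still refers to the same tensor; in neither case does the kernel own the tensor.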

◆ ~CLOneHotKernel()

arm_compute::CLOneHotKernel::~CLOneHotKernel ( )
default

Default destructor


Member Function Documentation

◆ configure() [1/2]

void arm_compute::CLOneHotKernel::configure ( const ICLTensor *  indices,
const ICLTensor *  on_value,
const ICLTensor *  off_value,
ICLTensor *  output,
int  depth,
int  axis = -1 
)

Initialise the kernel's inputs and output

Parameters
[in]  indices    Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32.
[in]  on_value   On value tensor. Supported tensor rank: only 1. Data type supported: U8/S8/U16/S16/F16/U32/S32/F32.
[in]  off_value  Off value tensor. Supported tensor rank: only 1. Data type supported: same as on_value.
[out] output     Destination tensor. Data type supported: same as on_value.
[in]  depth      The depth of the one-hot dimension.
[in]  axis       (Optional) The axis to fill. Negative values wrap around. Defaults to -1. The value must be in the range [-indices.rank, indices.rank).

Definition at line 99 of file CLOneHotKernel.cpp.

{
    _is_off_value_memset = false;
    ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, off_value, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(off_value->info());
    ARM_COMPUTE_ERROR_ON(off_value->info()->tensor_shape().total_size() != 1);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
    _off_value = off_value;
    configure_common(indices, on_value, output, depth, axis);
}

Referenced by arm_compute::CLOneHot::configure(), and arm_compute::CLOneHot::configure().
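The documented constraints on indices and axis can be expressed as small helpers. This sketch is based only on the parameter list above, not on the body of the library's internal validate_arguments(); `wrap_axis`, `indices_ok` and the `DataType` enum are hypothetical. It requires C++17 for `std::optional`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical enum covering the data types named in the parameter list.
enum class DataType { U8, S8, U16, S16, F16, U32, S32, F32 };

// Normalises an axis according to the documented rule: valid values lie in
// [-rank, rank), and a negative value wraps around by adding rank.
// Returns std::nullopt when the axis is out of range.
std::optional<int> wrap_axis(int axis, int rank)
{
    if (axis < -rank || axis >= rank)
    {
        return std::nullopt;
    }
    return axis < 0 ? axis + rank : axis;
}

// Checks the documented indices constraints: rank up to 3, type U32 or S32.
bool indices_ok(int rank, DataType type)
{
    return rank <= 3 && (type == DataType::U32 || type == DataType::S32);
}
```

For a 1-D indices tensor the documented range [-1, 1) admits only -1 and 0, and both wrap to the same axis.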

◆ configure() [2/2]

void arm_compute::CLOneHotKernel::configure ( const ICLTensor *  indices,
const ICLTensor *  on_value,
ICLTensor *  output,
int  depth,
int  axis = -1 
)

Initialise the kernel's inputs and output, where the output has already been initialised to off_value.

Parameters
[in]  indices    Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32.
[in]  on_value   On value tensor. Supported tensor rank: only 1. Data type supported: U8/S8/U16/S16/F16/U32/S32/F32.
[out] output     Destination tensor. Data type supported: same as on_value.
[in]  depth      The depth of the one-hot dimension.
[in]  axis       (Optional) The axis to fill. Negative values wrap around. Defaults to -1. The value must be in the range [-indices.rank, indices.rank).

Definition at line 110 of file CLOneHotKernel.cpp.

{
    _is_off_value_memset = true;
    ARM_COMPUTE_ERROR_ON_NULLPTR(indices, on_value, output);
    configure_common(indices, on_value, output, depth, axis);
}
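The two configure() overloads differ only in who supplies the off value. The sketch below contrasts them on plain vectors, assuming a 1-D indices tensor with the one-hot dimension innermost: `one_hot_explicit` writes every element, whereas `one_hot_on_only` assumes the output is already filled with off_value and writes only the "on" positions. Both function names are hypothetical, and skipping out-of-range indices in the second one is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mode 1 (explicit off_value): every output element is written.
// Assumes output.size() == indices.size() * depth.
template <typename T>
void one_hot_explicit(const std::vector<int32_t> &indices, int depth, T on_value, T off_value,
                      std::vector<T> &output)
{
    const std::size_t d_size = static_cast<std::size_t>(depth);
    for (std::size_t i = 0; i < indices.size(); ++i)
    {
        for (int d = 0; d < depth; ++d)
        {
            output[i * d_size + static_cast<std::size_t>(d)] = (indices[i] == d) ? on_value : off_value;
        }
    }
}

// Mode 2 (pre-filled output): the caller has already filled the output with
// off_value, so only the position selected by each in-range index is written.
template <typename T>
void one_hot_on_only(const std::vector<int32_t> &indices, int depth, T on_value,
                     std::vector<T> &output)
{
    const std::size_t d_size = static_cast<std::size_t>(depth);
    for (std::size_t i = 0; i < indices.size(); ++i)
    {
        if (indices[i] >= 0 && indices[i] < depth)
        {
            output[i * d_size + static_cast<std::size_t>(indices[i])] = on_value;
        }
    }
}
```

Given the same inputs, both modes should produce identical output; the pre-filled mode presumably avoids reading an off_value tensor at the cost of filling the output before the kernel runs.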

◆ operator=() [1/2]

CLOneHotKernel & arm_compute::CLOneHotKernel::operator= ( CLOneHotKernel &&  )
default

Allow instances of this class to be moved

◆ operator=() [2/2]

CLOneHotKernel & arm_compute::CLOneHotKernel::operator= ( const CLOneHotKernel & )
delete

Prevent instances of this class from being copied (as this class contains pointers).

◆ run()

void arm_compute::CLOneHotKernel::run ( const Window &  window,
cl::CommandQueue &  queue 
)
override

Definition at line 173 of file CLOneHotKernel.cpp.

{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
    Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
    unsigned int idx = 0;
    add_3D_tensor_argument(idx, _indices, window_collapsed);
    add_1D_tensor_argument(idx, _on_value, window_collapsed);
    if (!_is_off_value_memset)
    {
        add_1D_tensor_argument(idx, _off_value, window_collapsed);
    }
    add_4D_tensor_argument(idx, _output, window_collapsed);
    enqueue(queue, *this, window_collapsed, lws_hint());
}

Referenced by package.infer.session::inference().
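run() binds its arguments in a fixed order, and the off_value tensor is bound only when the kernel was configured with one. The hypothetical helper below lists that order as strings so the two cases can be compared; it does not call the OpenCL API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of the argument order used by run().
// off_value is bound only when the kernel was configured with an explicit
// off_value tensor (is_off_value_memset == false).
std::vector<std::string> kernel_argument_order(bool is_off_value_memset)
{
    std::vector<std::string> args;
    args.push_back("indices (3D)");
    args.push_back("on_value (1D)");
    if (!is_off_value_memset)
    {
        args.push_back("off_value (1D)");
    }
    args.push_back("output (4D)");
    return args;
}
```

Because the argument index `idx` is shared across the calls, omitting off_value moves the output tensor up by one slot, so the OpenCL program used in each mode has to declare a matching parameter list.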

◆ validate() [1/2]

Status arm_compute::CLOneHotKernel::validate ( const ITensorInfo *  indices,
const ITensorInfo *  on_value,
const ITensorInfo *  off_value,
const ITensorInfo *  output,
int  depth,
int  axis = -1 
)
static

Static function to check whether the given info will lead to a valid configuration of CLOneHotKernel.

Parameters
[in]  indices    Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32.
[in]  on_value   On value tensor. Supported tensor rank: only 1. Data type supported: U8/S8/U16/S16/F16/U32/S32/F32.
[in]  off_value  Off value tensor. Supported tensor rank: only 1. Data type supported: same as on_value.
[in]  output     Destination tensor. Data type supported: same as on_value.
[in]  depth      The depth of the one-hot dimension.
[in]  axis       (Optional) The axis to fill. Negative values wrap around. Defaults to -1. The value must be in the range [-indices.rank, indices.rank).
Returns
A status.

Definition at line 149 of file CLOneHotKernel.cpp.

{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(off_value);
    ARM_COMPUTE_RETURN_ERROR_ON(off_value->tensor_shape().total_size() != 1);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(on_value, off_value);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(indices, on_value, output, depth, axis));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(indices->clone().get(),
                                                              on_value->clone().get(),
                                                              output->clone().get(), depth, axis)
                                    .first);
    return Status{};
}

Referenced by arm_compute::CLOneHot::validate().
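The first three checks in validate() [1/2] concern only off_value. A minimal sketch of them follows, with hypothetical `TensorInfoDesc` and `CheckStatus` types standing in for ITensorInfo and Status; the shared checks delegated to validate_arguments() and validate_and_configure_window() are not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical enum for the element types.
enum class DataType { U8, S8, U16, S16, F16, U32, S32, F32 };

// Hypothetical stand-in for ITensorInfo: only the fields these checks need.
struct TensorInfoDesc
{
    DataType    data_type;
    std::size_t num_elements; // plays the role of tensor_shape().total_size()
};

// Hypothetical stand-in for Status.
struct CheckStatus
{
    bool        ok;
    std::string message;
};

// Mirrors the off_value-specific checks: off_value must be non-null, hold
// exactly one element, and share on_value's data type.
// on_value is assumed to be non-null here.
CheckStatus validate_off_value(const TensorInfoDesc *on_value, const TensorInfoDesc *off_value)
{
    if (off_value == nullptr)
    {
        return {false, "off_value is null"};
    }
    if (off_value->num_elements != 1)
    {
        return {false, "off_value must contain exactly one element"};
    }
    if (on_value->data_type != off_value->data_type)
    {
        return {false, "on_value and off_value data types differ"};
    }
    return {true, ""};
}
```

validate() [2/2] runs the same shared checks without these three, which matches the overload used when the output has already been filled with off_value.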

◆ validate() [2/2]

Status arm_compute::CLOneHotKernel::validate ( const ITensorInfo *  indices,
const ITensorInfo *  on_value,
const ITensorInfo *  output,
int  depth,
int  axis = -1 
)
static

Static function to check whether the given info will lead to a valid configuration of CLOneHotKernel without an off_value tensor (that is, with the output already initialised to off_value).

Parameters
[in]  indices    Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32.
[in]  on_value   On value tensor. Supported tensor rank: only 1. Data type supported: U8/S8/U16/S16/F16/U32/S32/F32.
[in]  output     Destination tensor. Data type supported: same as on_value.
[in]  depth      The depth of the one-hot dimension.
[in]  axis       (Optional) The axis to fill. Negative values wrap around. Defaults to -1. The value must be in the range [-indices.rank, indices.rank).
Returns
A status.

Definition at line 163 of file CLOneHotKernel.cpp.

{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(indices, on_value, output, depth, axis));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(indices->clone().get(),
                                                              on_value->clone().get(),
                                                              output->clone().get(), depth, axis)
                                    .first);
    return Status{};
}

The documentation for this class was generated from the following files: