ONE - On-device Neural Engine
Loading...
Searching...
No Matches
arm_compute::NEEmbeddingLookupKernel Class Reference

#include <NEEmbeddingLookupKernel.h>

Collaboration diagram for arm_compute::NEEmbeddingLookupKernel:

Public Member Functions

const char * name () const override
 
 NEEmbeddingLookupKernel ()
 
 NEEmbeddingLookupKernel (const NEEmbeddingLookupKernel &)=delete
 
NEEmbeddingLookupKernel & operator= (const NEEmbeddingLookupKernel &)=delete
 
 NEEmbeddingLookupKernel (NEEmbeddingLookupKernel &&)=default
 
NEEmbeddingLookupKerneloperator= (NEEmbeddingLookupKernel &&)=default
 
void configure (const ITensor *input, ITensor *output, const ITensor *lookups)
 
void run (const Window &window, const ThreadInfo &info) override
 

Static Public Member Functions

static Status validate (const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *lookups)
 

Detailed Description

NEON kernel to perform the EmbeddingLookup operation.

Definition at line 52 of file NEEmbeddingLookupKernel.h.

Constructor & Destructor Documentation

◆ NEEmbeddingLookupKernel() [1/3]

NEEmbeddingLookupKernel::NEEmbeddingLookupKernel ( )

Default constructor

Definition at line 55 of file NEEmbeddingLookupKernel.cpp.

56 : _input(nullptr), _lookups(nullptr), _output(nullptr)
57{
58}

◆ NEEmbeddingLookupKernel() [2/3]

arm_compute::NEEmbeddingLookupKernel::NEEmbeddingLookupKernel ( const NEEmbeddingLookupKernel &  )
delete

Prevent instances of this class from being copied (as this class contains pointers).

◆ NEEmbeddingLookupKernel() [3/3]

arm_compute::NEEmbeddingLookupKernel::NEEmbeddingLookupKernel ( NEEmbeddingLookupKernel &&  )
default

Allow instances of this class to be moved

Member Function Documentation

◆ configure()

void NEEmbeddingLookupKernel::configure ( const ITensor *  input,
ITensor *  output,
const ITensor *  lookups 
)

Initialize the kernel's input, output and lookups tensors.

Parameters
[in] input Source tensor. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32.
[out] output Destination tensor. Data types supported: same as input.
[in] lookups Lookups tensor: a 1D tensor whose values are indices into the outermost dimension of input (the highest-index dimension in Compute Library ordering). Data types supported: S32.

Definition at line 60 of file NEEmbeddingLookupKernel.cpp.

62{
63 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
64 ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), lookups->info()));
65
66 _input = input;
67 _output = output;
68 _lookups = lookups;
69
70 // Auto initialize output if not initialized
71 auto out_shape = input->info()->tensor_shape();
72 out_shape.set(out_shape.num_dimensions() - 1, lookups->info()->num_dimensions());
73 auto_init_if_empty(*output->info(), out_shape, 1, input->info()->data_type(),
74 input->info()->quantization_info());
75
76 INEKernel::configure(calculate_max_window(*output->info()));
77}
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *lookups)

References validate().

◆ name()

const char * arm_compute::NEEmbeddingLookupKernel::name ( ) const
inline override

Definition at line 55 of file NEEmbeddingLookupKernel.h.

55{ return "NEEmbeddingLookupKernel"; }

◆ operator=() [1/2]

NEEmbeddingLookupKernel & arm_compute::NEEmbeddingLookupKernel::operator= ( const NEEmbeddingLookupKernel &  )
delete

Prevent instances of this class from being copied (as this class contains pointers).

◆ operator=() [2/2]

NEEmbeddingLookupKernel & arm_compute::NEEmbeddingLookupKernel::operator= ( NEEmbeddingLookupKernel &&  )
default

Allow instances of this class to be moved

References validate().

◆ run()

void NEEmbeddingLookupKernel::run ( const Window &  window,
const ThreadInfo &  info 
)
override

Definition at line 107 of file NEEmbeddingLookupKernel.cpp.

108{
109 ARM_COMPUTE_UNUSED(info);
110 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
111 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
112
113 const size_t lookup_dim = _output->info()->num_dimensions() - 1;
114
115 Window output_window{window};
116 output_window.set(Window::DimX,
117 Window::Dimension(output_window.x().start(), output_window.x().end(),
118 _input->info()->dimension(0)));
119
120 Window out_slice = output_window.first_slice_window_4D();
121 do
122 {
123 Iterator output_it(_output, out_slice);
124
125 execute_window_loop(
126 out_slice,
127 [&](const Coordinates &id) {
128 const int32_t lookup =
129 *reinterpret_cast<int32_t *>(_lookups->ptr_to_element(Coordinates{id[lookup_dim]}));
130 Coordinates input_id{id};
131 input_id.set(lookup_dim, lookup);
132 memcpy(output_it.ptr(), _input->ptr_to_element(input_id),
133 _output->info()->dimension(0) * _output->info()->element_size());
134 },
135 output_it);
136
137 } while (window.slide_window_slice_4D(out_slice));
138}
volatile const char info[]

References info.

Referenced by package.infer.session::inference().

◆ validate()

Status NEEmbeddingLookupKernel::validate ( const ITensorInfo *  input,
const ITensorInfo *  output,
const ITensorInfo *  lookups 
)
static

Static function to check whether the given tensor infos lead to a valid configuration of NEEmbeddingLookupKernel.

Parameters
[in] input Source tensor info. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32.
[in] output Destination tensor info. Data types supported: same as input.
[in] lookups Lookups tensor info. Data types supported: S32.
Returns
A Status describing whether the configuration is valid.

Definition at line 79 of file NEEmbeddingLookupKernel.cpp.

82{
83 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, lookups);
84 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
85 input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
86 DataType::U32, DataType::S32, DataType::F16, DataType::F32);
87 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lookups, 1, DataType::S32);
88
89 ARM_COMPUTE_ERROR_ON(input->num_dimensions() < 2 && input->num_dimensions() > 4);
90 ARM_COMPUTE_ERROR_ON(lookups->num_dimensions() > 1);
91
92 // Validate in case of configured output
93 if (output->total_size() > 0)
94 {
95 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
96 ARM_COMPUTE_ERROR_ON(input->num_dimensions() != output->num_dimensions());
97 ARM_COMPUTE_ERROR_ON(output->dimension(output->num_dimensions() - 1) != lookups->dimension(0));
98 for (size_t i = 0; i < output->num_dimensions() - 1; ++i)
99 {
100 ARM_COMPUTE_ERROR_ON(input->dimension(i) != output->dimension(i));
101 }
102 }
103
104 return Status{};
105}

Referenced by configure(), and operator=().


The documentation for this class was generated from the following files: