ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
arm_compute::NEHashtableLookupKernel Class Reference

#include <NEHashtableLookupKernel.h>

Collaboration diagram for arm_compute::NEHashtableLookupKernel:

Public Member Functions

const char * name () const override
 
 NEHashtableLookupKernel ()
 
 NEHashtableLookupKernel (const NEHashtableLookupKernel &)=delete
 
NEHashtableLookupKernel & operator= (const NEHashtableLookupKernel &)=delete
 
 NEHashtableLookupKernel (NEHashtableLookupKernel &&)=default
 
NEHashtableLookupKernel & operator= (NEHashtableLookupKernel &&)=default
 
void configure (const ITensor *lookups, const ITensor *keys, const ITensor *input, ITensor *output, ITensor *hits)
 
void run (const Window &window, const ThreadInfo &info) override
 

Static Public Member Functions

static Status validate (const ITensorInfo *lookups, const ITensorInfo *keys, const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *hits)
 

Detailed Description

NEON kernel to perform HashtableLookup operation

Definition at line 52 of file NEHashtableLookupKernel.h.

Constructor & Destructor Documentation

◆ NEHashtableLookupKernel() [1/3]

NEHashtableLookupKernel::NEHashtableLookupKernel ( )

Default constructor

Definition at line 62 of file NEHashtableLookupKernel.cpp.

63 : _lookups(nullptr), _keys(nullptr), _input(nullptr), _output(nullptr), _hits{nullptr}
64{
65}

◆ NEHashtableLookupKernel() [2/3]

arm_compute::NEHashtableLookupKernel::NEHashtableLookupKernel ( const NEHashtableLookupKernel & )
delete

Prevent instances of this class from being copied (As this class contains pointers).

◆ NEHashtableLookupKernel() [3/3]

arm_compute::NEHashtableLookupKernel::NEHashtableLookupKernel ( NEHashtableLookupKernel &&  )
default

Allow instances of this class to be moved

Member Function Documentation

◆ configure()

void NEHashtableLookupKernel::configure ( const ITensor *  lookups,
const ITensor *  keys,
const ITensor *  input,
ITensor *  output,
ITensor *  hits 
)

Initialize the kernel's inputs and outputs.

Parameters
[in] lookups  Lookups 1D tensor whose values are indices into the first dimension of input. Data types supported: S32
[in] keys  Keys 1D tensor. keys and input pair represent a map. Data types supported: S32
[in] input  Source tensor. Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
[out] output  Destination tensor. Data types and data layouts supported: Same as input.
[out] hits  Hits 1D tensor. A boolean tensor that indicates whether the lookup hits (True) or not (False). Data types supported: U8/QASYMM8.

Definition at line 67 of file NEHashtableLookupKernel.cpp.

69{
70 ARM_COMPUTE_ERROR_ON_NULLPTR(lookups, keys, input, output, hits);
71 ARM_COMPUTE_ERROR_THROW_ON(
72 validate(lookups->info(), keys->info(), input->info(), output->info(), hits->info()));
73
74 _lookups = lookups;
75 _keys = keys;
76 _input = input;
77 _output = output;
78 _hits = hits;
79
80 // Auto initialize output if not initialized
81 auto out_shape{input->info()->tensor_shape()};
82 out_shape.set(out_shape.num_dimensions() - 1, lookups->info()->num_dimensions(), false);
83 auto_init_if_empty(*output->info(), out_shape, 1, input->info()->data_type(),
84 input->info()->quantization_info());
85
86 // Auto initialize hits if not initialized
87 auto_init_if_empty(*hits->info(), lookups->info()->tensor_shape(), 1, DataType::U8);
88
89 INEKernel::configure(calculate_max_window(*output->info()));
90}
static Status validate(const ITensorInfo *lookups, const ITensorInfo *keys, const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *hits)

References validate().

◆ name()

const char * arm_compute::NEHashtableLookupKernel::name ( ) const
inline override

Definition at line 55 of file NEHashtableLookupKernel.h.

55{ return "NEHashtableLookupKernel"; }

◆ operator=() [1/2]

NEHashtableLookupKernel & arm_compute::NEHashtableLookupKernel::operator= ( const NEHashtableLookupKernel & )
delete

Prevent instances of this class from being copied (As this class contains pointers).

◆ operator=() [2/2]

NEHashtableLookupKernel & arm_compute::NEHashtableLookupKernel::operator= ( NEHashtableLookupKernel &&  )
default

Allow instances of this class to be moved

References validate().

◆ run()

void NEHashtableLookupKernel::run ( const Window &  window,
const ThreadInfo &  info 
)
override

Definition at line 132 of file NEHashtableLookupKernel.cpp.

133{
134 ARM_COMPUTE_UNUSED(info);
135 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
136 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
137
138 const size_t lookup_dim = _output->info()->num_dimensions() - 1;
139 const int const_0 = _output->info()->data_type() == DataType::QASYMM8
140 ? _output->info()->quantization_info().uniform().offset
141 : 0;
142
143 std::unordered_map<int32_t, size_t> key_index_map;
144 for (size_t n = 0; n < _keys->info()->dimension(0); ++n)
145 {
146 const int32_t key = *reinterpret_cast<int32_t *>(_keys->ptr_to_element({n}));
147 key_index_map[key] = n;
148 }
149 std::vector<size_t> lookup_indices;
150 for (size_t k = 0; k < _lookups->info()->dimension(0); ++k)
151 {
152 const int32_t key = *reinterpret_cast<int32_t *>(_lookups->ptr_to_element({k}));
153 const auto it = key_index_map.find(key);
154 if (it == key_index_map.end())
155 {
156 lookup_indices.emplace_back(NOT_HIT);
157 *_hits->ptr_to_element({k}) = 0;
158 }
159 else
160 {
161#if defined(ARM_COMPUTE_DEBUG_ENABLED)
162 if (it->second >= _keys->info()->dimension(0))
163 ARM_COMPUTE_ERROR("HashTable Lookup: Index out of bounds.");
164#endif // defined(ARM_COMPUTE_DEBUG_ENABLED)
165 lookup_indices.emplace_back(it->second);
166 *_hits->ptr_to_element({k}) = 1;
167 }
168 }
169
170 Window output_window{window};
171 output_window.set(Window::DimX,
172 Window::Dimension(output_window.x().start(), output_window.x().end(),
173 _input->info()->dimension(0)));
174
175 Window out_slice = output_window.first_slice_window_4D();
176 do
177 {
178 Iterator output_it(_output, out_slice);
179
180 execute_window_loop(
181 out_slice,
182 [&](const Coordinates &id) {
183 const auto lookup = lookup_indices.at(id[lookup_dim]);
184 if (lookup == NOT_HIT)
185 {
186 memset(output_it.ptr(), const_0,
187 _output->info()->dimension(0) * _output->info()->element_size());
188 }
189 else
190 {
191 Coordinates input_id{id};
192 input_id.set(lookup_dim, lookup);
193 memcpy(output_it.ptr(), _input->ptr_to_element(input_id),
194 _output->info()->dimension(0) * _output->info()->element_size());
195 }
196 },
197 output_it);
198
199 } while (window.slide_window_slice_4D(out_slice));
200}
volatile const char info[]

References info.

◆ validate()

Status NEHashtableLookupKernel::validate ( const ITensorInfo *  lookups,
const ITensorInfo *  keys,
const ITensorInfo *  input,
const ITensorInfo *  output,
const ITensorInfo *  hits 
)
static

Static function to check whether the given tensor infos lead to a valid configuration of NEHashtableLookupKernel.

Parameters
[in] lookups  The lookups tensor info. Data types supported: S32.
[in] keys  The keys tensor info. keys and input pair represent a map. Data types supported: S32
[in] input  The input tensor info. Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
[out] output  The output tensor info. Data types and data layouts supported: Same as input.
[out] hits  The hits tensor info. A boolean tensor that indicates whether the lookup hits (True) or not (False). Data types supported: U8/QASYMM8
Returns
a status

Definition at line 92 of file NEHashtableLookupKernel.cpp.

95{
96 ARM_COMPUTE_ERROR_ON_NULLPTR(lookups, keys, input, output, hits);
97 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
98 input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
99 DataType::U32, DataType::S32, DataType::F16, DataType::F32);
100 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lookups, 1, DataType::S32);
101 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(keys, 1, DataType::S32);
102
103 ARM_COMPUTE_ERROR_ON(input->num_dimensions() < 2 && input->num_dimensions() > 4);
104 ARM_COMPUTE_ERROR_ON(lookups->num_dimensions() > 1);
105 ARM_COMPUTE_ERROR_ON(keys->num_dimensions() > 1);
106 ARM_COMPUTE_ERROR_ON(keys->dimension(0) != input->dimension(input->num_dimensions() - 1));
107
108 // Validate in case of configured output
109 if (output->total_size() > 0)
110 {
111 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
112 ARM_COMPUTE_ERROR_ON(input->num_dimensions() != output->num_dimensions());
113 ARM_COMPUTE_ERROR_ON(output->dimension(output->num_dimensions() - 1) != lookups->dimension(0));
114 for (size_t i = 0; i < output->num_dimensions() - 1; ++i)
115 {
116 ARM_COMPUTE_ERROR_ON(input->dimension(i) != output->dimension(i));
117 }
118 }
119
120 // Validate in case of configured hits
121 if (hits->total_size() > 0)
122 {
123 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(hits, 1, DataType::U8, DataType::QASYMM8);
124 ARM_COMPUTE_ERROR_ON(hits->dimension(0) != output->dimension(output->num_dimensions() - 1));
125 ARM_COMPUTE_ERROR_ON(hits->dimension(0) != lookups->dimension(0));
126 ARM_COMPUTE_ERROR_ON(hits->num_dimensions() > 1);
127 }
128
129 return Status{};
130}

Referenced by configure(), operator=(), and arm_compute::NEHashtableLookup::validate().


The documentation for this class was generated from the following files: