ONE - On-device Neural Engine
arm_compute::NECastBoolKernel Class Reference

Class for the kernel converting boolean type.

#include <NECastBoolKernel.h>

[Collaboration diagram for arm_compute::NECastBoolKernel]

Public Member Functions

const char * name () const override
 
 NECastBoolKernel ()
 
 NECastBoolKernel (const NECastBoolKernel &)=delete
 
 NECastBoolKernel (NECastBoolKernel &&)=default
 
NECastBoolKernel & operator= (const NECastBoolKernel &)=delete
 
NECastBoolKernel & operator= (NECastBoolKernel &&)=default
 
void configure (const ITensor *input, ITensor *output)
 
void run (const Window &window, const ThreadInfo &info) override
 

Static Public Member Functions

static Status validate (const ITensorInfo *input, const ITensorInfo *output)
 

Detailed Description

Class for the kernel converting boolean type.

Definition at line 52 of file NECastBoolKernel.h.

Constructor & Destructor Documentation

◆ NECastBoolKernel() [1/3]

NECastBoolKernel::NECastBoolKernel ( )

Default constructor

Definition at line 79 of file NECastBoolKernel.cpp.

79 : _input(nullptr), _output(nullptr) {}

◆ NECastBoolKernel() [2/3]

arm_compute::NECastBoolKernel::NECastBoolKernel ( const NECastBoolKernel &  )
delete

Prevent instances of this class from being copied (As this class contains pointers)

◆ NECastBoolKernel() [3/3]

arm_compute::NECastBoolKernel::NECastBoolKernel ( NECastBoolKernel &&  )
default

Default move constructor

Member Function Documentation

◆ configure()

void NECastBoolKernel::configure ( const ITensor *  input,
ITensor *  output 
)

Set the input and output of the kernel

Valid conversions Input -> Output:

  • U8 -> U8, S8, U16, S16, U32, S32, F32, F16

Parameters
[in]  input   The input tensor to convert. Data types supported: U8
[out] output  The output tensor. Data types supported: U8/S8/U16/S16/U32/S32/F16/F32.

Definition at line 81 of file NECastBoolKernel.cpp.

82{
83 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
84
85 // Auto initialize output shape if not initialized (We can only auto-configure the shape, datatype
86 // must be given)
87 set_shape_if_empty(*output->info(), input->info()->tensor_shape());
88
89 _input = input;
90 _output = output;
91
92 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
93
94 // Configure kernel window
95 Window win = calculate_max_window(*input->info(), Steps());
96 Coordinates coord;
97 coord.set_num_dimensions(output->info()->num_dimensions());
98 output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
99
100 ICPPKernel::configure(win);
101}
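
Below is a minimal usage sketch for configure(). Only NECastBoolKernel::configure() and the supported data types come from this page; the tensor shapes, include paths, and the use of Tensor, TensorInfo and NEScheduler to allocate memory and dispatch the kernel are illustrative assumptions about the surrounding Compute Library runtime.

#include <NECastBoolKernel.h>                      // as included above; exact path may differ per build
#include "arm_compute/runtime/Tensor.h"            // assumed runtime headers
#include "arm_compute/runtime/NEON/NEScheduler.h"

using namespace arm_compute;

int main()
{
    // Boolean input stored as U8 (one 0/1 value per element), up-cast to S32.
    Tensor src;
    Tensor dst;
    src.allocator()->init(TensorInfo(TensorShape(64U, 8U), 1, DataType::U8));
    dst.allocator()->init(TensorInfo(TensorShape(64U, 8U), 1, DataType::S32));
    src.allocator()->allocate();
    dst.allocator()->allocate();

    NECastBoolKernel cast_bool;
    cast_bool.configure(&src, &dst);                       // sets up the kernel window (listing above)
    NEScheduler::get().schedule(&cast_bool, Window::DimY); // runs NECastBoolKernel::run() per thread
    return 0;
}

In practice this kernel is presumably driven through the arm_compute::NECastBool runtime function (see the cross-reference under validate()), which would wrap a configure/schedule sequence of this kind.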

◆ name()

const char * arm_compute::NECastBoolKernel::name ( ) const
inline override

Definition at line 55 of file NECastBoolKernel.h.

55{ return "NECastBoolKernel"; }

◆ operator=() [1/2]

NECastBoolKernel & arm_compute::NECastBoolKernel::operator= ( const NECastBoolKernel &  )
delete

Prevent instances of this class from being copied (As this class contains pointers)

◆ operator=() [2/2]

NECastBoolKernel & arm_compute::NECastBoolKernel::operator= ( NECastBoolKernel &&  )
default

Default move assignment operator

References validate().

◆ run()

void NECastBoolKernel::run ( const Window &  window,
const ThreadInfo &  info 
)
override

Definition at line 109 of file NECastBoolKernel.cpp.

110{
111 ARM_COMPUTE_UNUSED(info);
112 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
113 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
114 ARM_COMPUTE_ERROR_ON_NULLPTR(_input, _output);
115 ARM_COMPUTE_ERROR_ON(_input == _output);
116
117 const auto window_start_x = static_cast<int>(window.x().start());
118 const auto window_end_x = static_cast<int>(window.x().end());
119 const int window_step_x = 16;
120
121 Window win{window};
122 win.set(Window::DimX, Window::Dimension(0, 1, 1));
123
124 Iterator input(_input, win);
125 Iterator output(_output, win);
126
127 const uint8_t true_val = 1;
128 const uint8x8_t mask_bool = vdup_n_u8(true_val);
129
130 switch (_output->info()->data_type())
131 {
132 case DataType::S8:
133 {
134 /* Conversion U8 -> S8 */
135 execute_window_loop(
136 win,
137 [&](const Coordinates &) {
138 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
139 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
140
141 int x = window_start_x;
142 for (; x <= (window_end_x - window_step_x); x += window_step_x)
143 {
144 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
145
146 vst1q_s8(output_ptr + x,
147 vreinterpretq_s8_u8(vandq_u8(texels_u8, vdupq_n_u8(true_val))));
148 }
149
150 // Compute left-over elements
151 for (; x < window_end_x; ++x)
152 {
153 *(output_ptr + x) = static_cast<int8_t>(*(input_ptr + x) & true_val);
154 }
155 },
156 input, output);
157 break;
158 }
159 case DataType::S16:
160 {
161 /* Up-conversion U8 -> S16 */
162 execute_window_loop(
163 win,
164 [&](const Coordinates &) {
165 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
166 const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
167
168 int x = window_start_x;
169 for (; x <= (window_end_x - window_step_x); x += window_step_x)
170 {
171 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
172
173 const int16x8x2_t texels = {
174 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
175 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
176
177 vst1q_s16(output_ptr + x, texels.val[0]);
178 vst1q_s16(output_ptr + x + 8, texels.val[1]);
179 }
180
181 // Compute left-over elements
182 for (; x < window_end_x; ++x)
183 {
184 *(output_ptr + x) = static_cast<int32_t>(*(input_ptr + x) & true_val);
185 }
186 },
187 input, output);
188 break;
189 }
190 case DataType::S32:
191 {
192 /* Up-conversion U8 -> S32 */
193 execute_window_loop(
194 win,
195 [&](const Coordinates &) {
196 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
197 const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());
198
199 int x = window_start_x;
200 for (; x <= (window_end_x - window_step_x); x += window_step_x)
201 {
202 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
203
204 const int16x8x2_t texels = {
205 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
206 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
207
208 vst1q_s32(output_ptr + x, vmovl_s16(vget_low_s16(texels.val[0])));
209 vst1q_s32(output_ptr + x + 4, vmovl_s16(vget_high_s16(texels.val[0])));
210 vst1q_s32(output_ptr + x + 8, vmovl_s16(vget_low_s16(texels.val[1])));
211 vst1q_s32(output_ptr + x + 12, vmovl_s16(vget_high_s16(texels.val[1])));
212 }
213
214 // Compute left-over elements
215 for (; x < window_end_x; ++x)
216 {
217 *(output_ptr + x) = static_cast<uint32_t>(*(input_ptr + x) & true_val);
218 }
219 },
220 input, output);
221 break;
222 }
223 case DataType::F32:
224 {
225 /* Up-conversion U8 -> F32 */
226 execute_window_loop(
227 win,
228 [&](const Coordinates &) {
229 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
230 const auto output_ptr = reinterpret_cast<float *>(output.ptr());
231
232 int x = window_start_x;
233 for (; x <= (window_end_x - window_step_x); x += window_step_x)
234 {
235 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
236
237 const int16x8x2_t texels = {
238 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
239 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
240 vst1q_f32(output_ptr + x, vcvtq_f32_s32(vmovl_s16(vget_low_s16(texels.val[0]))));
241 vst1q_f32(output_ptr + x + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(texels.val[0]))));
242 vst1q_f32(output_ptr + x + 8, vcvtq_f32_s32(vmovl_s16(vget_low_s16(texels.val[1]))));
243 vst1q_f32(output_ptr + x + 12, vcvtq_f32_s32(vmovl_s16(vget_high_s16(texels.val[1]))));
244 }
245
246 // Compute left-over elements
247 for (; x < window_end_x; ++x)
248 {
249 auto in = static_cast<uint32_t>(*(input_ptr + x) & true_val);
250 *(output_ptr + x) = static_cast<float>(in);
251 }
252 },
253 input, output);
254 break;
255 }
256#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
257 case DataType::F16:
258 {
259 /* Up-conversion U8 -> F16 */
260 execute_window_loop(
261 win,
262 [&](const Coordinates &) {
263 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
264 const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());
265
266 int x = window_start_x;
267 for (; x <= (window_end_x - window_step_x); x += window_step_x)
268 {
269 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
270
271 const int16x8x2_t texels = {
272 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
273 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
274 vst1q_f16(output_ptr + x, vcvtq_f16_s16(texels.val[0]));
275 vst1q_f16(output_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
276 }
277
278 // Compute left-over elements
279 for (; x < window_end_x; ++x)
280 {
281 *(output_ptr + x) = static_cast<float16_t>(*(input_ptr + x) & true_val);
282 }
283 },
284 input, output);
285 break;
286 }
287#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
288 case DataType::U8:
289 {
290 /* Conversion U8 -> U8 */
291 execute_window_loop(
292 win,
293 [&](const Coordinates &) {
294 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
295 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
296
297 int x = window_start_x;
298 for (; x <= (window_end_x - window_step_x); x += window_step_x)
299 {
300 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
301
302 vst1q_u8(output_ptr + x, vandq_u8(texels_u8, vdupq_n_u8(true_val)));
303 }
304
305 // Compute left-over elements
306 for (; x < window_end_x; ++x)
307 {
308 *(output_ptr + x) = static_cast<uint8_t>(*(input_ptr + x) & true_val);
309 }
310 },
311 input, output);
312 break;
313 }
314 case DataType::U16:
315 {
316 /* Up-conversion U8 -> U16 */
317 execute_window_loop(
318 win,
319 [&](const Coordinates &) {
320 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
321 const auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr());
322
323 int x = window_start_x;
324 for (; x <= (window_end_x - window_step_x); x += window_step_x)
325 {
326 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
327
328 const uint16x8x2_t texels = {{vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool)),
329 vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool))}};
330
331 vst1q_u16(output_ptr + x, texels.val[0]);
332 vst1q_u16(output_ptr + x + 8, texels.val[1]);
333 }
334
335 // Compute left-over elements
336 for (; x < window_end_x; ++x)
337 {
338 *(output_ptr + x) = static_cast<uint16_t>(*(input_ptr + x) & true_val);
339 }
340 },
341 input, output);
342 break;
343 }
344 default:
345 ARM_COMPUTE_ERROR("Output data type not supported");
346 }
347}
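
Every branch of the switch above applies the same element-wise rule: mask the U8 input down to its least-significant bit (true_val == 1, broadcast as mask_bool) and convert the resulting 0/1 value to the destination type. A scalar reference of that rule, with T standing for any supported output type, is sketched below.

#include <cstdint>

// Scalar equivalent of the vectorised branches in run(): keep only the
// least-significant bit of the boolean U8 input and cast the 0/1 result
// to the destination type T.
template <typename T>
void cast_bool_reference(const uint8_t *in, T *out, int n)
{
    for (int i = 0; i < n; ++i)
    {
        out[i] = static_cast<T>(in[i] & 1);
    }
}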

◆ validate()

Status NECastBoolKernel::validate ( const ITensorInfo *  input,
const ITensorInfo *  output 
)
static

Static function to check if the given info will lead to a valid configuration of NECastBoolKernel

Parameters
[in]  input   Source tensor info. Data types supported: U8
[in]  output  Destination tensor info. Data type supported: U8/S8/U16/S16/U32/S32/F16/F32.
Returns
a status

Definition at line 103 of file NECastBoolKernel.cpp.

104{
105 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
106 return Status{};
107}
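
A short sketch of calling validate() before creating the kernel is shown below; the tensor shapes, the F32 destination type, and the include paths are illustrative assumptions, while TensorInfo, Status and ErrorCode are standard Compute Library types.

#include <NECastBoolKernel.h>                 // as included above; exact path may differ per build
#include "arm_compute/core/Error.h"           // assumed location of Status / ErrorCode
#include "arm_compute/core/TensorInfo.h"

bool can_cast_bool_to_f32()
{
    using namespace arm_compute;
    const TensorInfo src(TensorShape(32U, 16U), 1, DataType::U8);
    const TensorInfo dst(TensorShape(32U, 16U), 1, DataType::F32);
    const Status status = NECastBoolKernel::validate(&src, &dst);
    return status.error_code() == ErrorCode::OK; // OK means the configuration is valid
}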

Referenced by operator=(), and arm_compute::NECastBool::validate().


The documentation for this class was generated from the following files:

  • NECastBoolKernel.h
  • NECastBoolKernel.cpp