18#ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19#define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
29#include "fixedpoint/fixedpoint.h"
39template <
typename ElementwiseF,
typename ScalarBroadcastF,
typename T>
42 const T *unswitched_input1_data,
44 const T *unswitched_input2_data,
45 const Shape & , T *output_data,
46 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
48 const T *input1_data = switch_inputs ? unswitched_input2_data : unswitched_input1_data;
49 const T *input2_data = switch_inputs ? unswitched_input1_data : unswitched_input2_data;
55 T *output_data_ptr = output_data;
56 const T *input1_data_ptr = input1_data;
57 const T *input2_data_reset = input2_data;
73 for (
int i0 = 0; i0 < y0; ++i0)
75 const T *input2_data_ptr =
nullptr;
76 for (
int i1 = 0; i1 < y1; ++i1)
78 input2_data_ptr = input2_data_reset;
79 for (
int i2 = 0; i2 < y2; ++i2)
81 for (
int i3 = 0; i3 < y3; ++i3)
83 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
84 input2_data_ptr += y4;
85 output_data_ptr += y4;
88 input1_data_ptr += y4;
92 input2_data_reset = input2_data_ptr;
107 for (
int i0 = 0; i0 < y0; ++i0)
109 const T *input2_data_ptr =
nullptr;
110 for (
int i1 = 0; i1 < y1; ++i1)
112 input2_data_ptr = input2_data_reset;
113 for (
int i2 = 0; i2 < y2; ++i2)
115 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
116 input2_data_ptr += y3;
117 output_data_ptr += y3;
118 input1_data_ptr += 1;
121 input2_data_reset = input2_data_ptr;
127template <
typename ElementwiseF,
typename ScalarBroadcastF,
typename T>
130 const T *unswitched_input1_data,
132 const T *unswitched_input2_data,
133 const Shape & , T *output_data,
134 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
144 const bool use_unswitched =
148 const T *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
149 const T *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
155 T *output_data_ptr = output_data;
156 const T *input1_data_ptr = input1_data;
157 const T *input2_data_reset = input2_data;
172 for (
int i0 = 0; i0 < y0; ++i0)
174 const T *input2_data_ptr =
nullptr;
175 for (
int i1 = 0; i1 < y1; ++i1)
177 input2_data_ptr = input2_data_reset;
178 for (
int i2 = 0; i2 < y2; ++i2)
180 for (
int i3 = 0; i3 < y3; ++i3)
182 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
183 input2_data_ptr += y4;
184 output_data_ptr += y4;
187 input1_data_ptr += y4;
191 input2_data_reset = input2_data_ptr;
206 for (
int i0 = 0; i0 < y0; ++i0)
208 const T *input2_data_ptr =
nullptr;
209 for (
int i1 = 0; i1 < y1; ++i1)
211 input2_data_ptr = input2_data_reset;
212 for (
int i2 = 0; i2 < y2; ++i2)
214 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
215 input2_data_ptr += y3;
216 output_data_ptr += y3;
217 input1_data_ptr += 1;
220 input2_data_reset = input2_data_ptr;
226inline typename std::enable_if_t<is_quant8<T>::value, int32_t>
229 const int32_t input1_val = params.
input1_offset + input1_data;
230 const int32_t input2_val = params.
input2_offset + input2_data;
231 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
232 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
237 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
243 return clamped_output;
247 const uint8_t *input1_data,
const uint8_t *input2_data,
248 uint8_t *output_data)
255 for (; i <=
size - 8; i += 8)
257 const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
258 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
259 const int16x8_t input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
260 const int16x8_t input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
261 const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.
input1_offset));
262 const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.
input2_offset));
263 const int16x4_t input1_val_high = vget_high_s16(input1_val);
264 const int16x4_t input1_val_low = vget_low_s16(input1_val);
265 const int16x4_t input2_val_high = vget_high_s16(input2_val);
266 const int16x4_t input2_val_low = vget_low_s16(input2_val);
267 int32x4_t x11 = vmovl_s16(input1_val_low);
268 int32x4_t x12 = vmovl_s16(input1_val_high);
269 int32x4_t x21 = vmovl_s16(input2_val_low);
270 int32x4_t x22 = vmovl_s16(input2_val_high);
271 const int32x4_t left_shift_dup = vdupq_n_s32(params.
left_shift);
272 x11 = vshlq_s32(x11, left_shift_dup);
273 x12 = vshlq_s32(x12, left_shift_dup);
274 x21 = vshlq_s32(x21, left_shift_dup);
275 x22 = vshlq_s32(x22, left_shift_dup);
280 const int32x4_t input1_shift_dup = vdupq_n_s32(params.
input1_shift);
281 const int32x4_t input2_shift_dup = vdupq_n_s32(params.
input2_shift);
282 x11 = vshlq_s32(x11, input1_shift_dup);
283 x12 = vshlq_s32(x12, input1_shift_dup);
284 x21 = vshlq_s32(x21, input2_shift_dup);
285 x22 = vshlq_s32(x22, input2_shift_dup);
286 int32x4_t s1 = vaddq_s32(x11, x21);
287 int32x4_t s2 = vaddq_s32(x12, x22);
290 using gemmlowp::RoundingDivideByPOT;
293 const int16x4_t s1_narrowed = vmovn_s32(s1);
294 const int16x4_t s2_narrowed = vmovn_s32(s2);
296 vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.
output_offset));
297 const uint8x8_t clamped =
298 vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
299 vst1_u8(output_data + i, clamped);
302 for (; i <
size; ++i)
304 const int32_t input1_val = params.
input1_offset + input1_data[i];
305 const int32_t input2_val = params.
input2_offset + input2_data[i];
306 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
307 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
312 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
318 output_data[i] =
static_cast<uint8_t
>(clamped_output);
323 const int8_t *input1_data,
const int8_t *input2_data,
333 const int32x4_t input1_left_dup = vdupq_n_s32(input1_left_shift);
334 const int32x4_t input2_left_dup = vdupq_n_s32(input2_left_shift);
336 const int16x8_t input1_offset_dup = vdupq_n_s16(params.
input1_offset);
337 const int16x8_t input2_offset_dup = vdupq_n_s16(params.
input2_offset);
339 for (; i <=
size - 16; i += 16)
341 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
342 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
344 const int16x8_t input1_val_s16_high = vmovl_s8(vget_high_s8(input1_val_original));
345 const int16x8_t input1_val_s16_low = vmovl_s8(vget_low_s8(input1_val_original));
347 const int16x8_t input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
348 const int16x8_t input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
349 const int16x8_t input1_val_high = vaddq_s16(input1_val_s16_high, input1_offset_dup);
350 const int16x8_t input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_dup);
351 const int16x8_t input1_val_low = vaddq_s16(input1_val_s16_low, input1_offset_dup);
352 const int16x8_t input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_dup);
353 const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
354 const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
355 const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
356 const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
357 const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
358 const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
359 const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
360 const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
361 int32x4_t x111 = vmovl_s16(input1_val_low_low);
362 int32x4_t x112 = vmovl_s16(input1_val_low_high);
363 int32x4_t x121 = vmovl_s16(input1_val_high_low);
364 int32x4_t x122 = vmovl_s16(input1_val_high_high);
365 int32x4_t x211 = vmovl_s16(input2_val_low_low);
366 int32x4_t x212 = vmovl_s16(input2_val_low_high);
367 int32x4_t x221 = vmovl_s16(input2_val_high_low);
368 int32x4_t x222 = vmovl_s16(input2_val_high_high);
370 x111 = vshlq_s32(x111, input1_left_dup);
371 x112 = vshlq_s32(x112, input1_left_dup);
372 x121 = vshlq_s32(x121, input1_left_dup);
373 x122 = vshlq_s32(x122, input1_left_dup);
374 x211 = vshlq_s32(x211, input2_left_dup);
375 x212 = vshlq_s32(x212, input2_left_dup);
376 x221 = vshlq_s32(x221, input2_left_dup);
377 x222 = vshlq_s32(x222, input2_left_dup);
386 int32x4_t s11 = vaddq_s32(x111, x211);
387 int32x4_t s12 = vaddq_s32(x112, x212);
388 int32x4_t s21 = vaddq_s32(x121, x221);
389 int32x4_t s22 = vaddq_s32(x122, x222);
394 using gemmlowp::RoundingDivideByPOT;
399 const int16x4_t s11_narrowed = vmovn_s32(s11);
400 const int16x4_t s12_narrowed = vmovn_s32(s12);
401 const int16x4_t s21_narrowed = vmovn_s32(s21);
402 const int16x4_t s22_narrowed = vmovn_s32(s22);
404 vaddq_s16(vcombine_s16(s11_narrowed, s12_narrowed), vdupq_n_s16(params.
output_offset));
406 vaddq_s16(vcombine_s16(s21_narrowed, s22_narrowed), vdupq_n_s16(params.
output_offset));
407 const int8x16_t s = vcombine_s8(vqmovn_s16(s1), vqmovn_s16(s2));
409 const int8x16_t clamped =
410 vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector, s));
411 vst1q_s8(output_data + i, clamped);
415 for (; i <
size; ++i)
417 const int32_t input1_val = params.
input1_offset + input1_data[i];
418 const int32_t input2_val = params.
input2_offset + input2_data[i];
419 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
420 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
425 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
431 output_data[i] =
static_cast<int8_t
>(clamped_output);
438 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
440 return vaddq_f32(a, b);
443 static inline float calculate(
const float a,
const float b) {
return a + b; }
449 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
451 return vsubq_f32(a, b);
454 static inline float calculate(
const float a,
const float b) {
return a - b; }
460 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
462 return vmulq_f32(a, b);
465 static inline float calculate(
const float a,
const float b) {
return a * b; }
472 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
474 return vdivq_f32(a, b);
478 static inline float calculate(
const float a,
const float b) {
return a / b; }
483 template <
typename T>
static inline T
calculate(
const T &a,
const T &b)
485 return BASEOPERATOR::calculate(b, a);
492 static inline float32x4_t
applyCeiling(
const float32x4_t &value,
const float32x4_t &ceilingParam)
497 static inline float32x4_t
applyFloor(
const float32x4_t &value,
const float32x4_t &floorParam)
503 static inline float applyCeiling(
const float value,
const float ceilingParam)
508 static inline float applyFloor(
const float value,
const float floorParam)
518 static inline float32x4_t
applyCeiling(
const float32x4_t &value,
const float32x4_t &ceilingParam)
523 static inline float32x4_t
applyFloor(
const float32x4_t &value,
const float32x4_t &floorParam)
525 return vmaxq_f32(value, floorParam);
528 static inline float applyCeiling(
const float value,
const float ceilingParam)
533 static inline float applyFloor(
const float value,
const float floorParam)
535 return std::max(value, floorParam);
542 static inline float32x4_t
applyCeiling(
const float32x4_t &value,
const float32x4_t &ceilingParam)
544 return vminq_f32(value, ceilingParam);
546 static inline float32x4_t
applyFloor(
const float32x4_t &value,
const float32x4_t &floorParam)
548 return vmaxq_f32(value, floorParam);
551 static inline float applyCeiling(
const float value,
const float ceilingParam)
553 return std::min(value, ceilingParam);
555 static inline float applyFloor(
const float value,
const float floorParam)
557 return std::max(value, floorParam);
561template <
class OPERATOR,
class ACTIVATION>
563 const float *input1_data,
const float *input2_data,
571 for (; i <=
size - 16; i += 16)
573 auto a10 = vld1q_f32(input1_data + i);
574 auto a11 = vld1q_f32(input1_data + i + 4);
575 auto a12 = vld1q_f32(input1_data + i + 8);
576 auto a13 = vld1q_f32(input1_data + i + 12);
577 auto a20 = vld1q_f32(input2_data + i);
578 auto a21 = vld1q_f32(input2_data + i + 4);
579 auto a22 = vld1q_f32(input2_data + i + 8);
580 auto a23 = vld1q_f32(input2_data + i + 12);
581 auto x0 = OPERATOR::calculate(a10, a20);
582 auto x1 = OPERATOR::calculate(a11, a21);
583 auto x2 = OPERATOR::calculate(a12, a22);
584 auto x3 = OPERATOR::calculate(a13, a23);
585 x0 = ACTIVATION::applyFloor(x0, activation_min);
586 x1 = ACTIVATION::applyFloor(x1, activation_min);
587 x2 = ACTIVATION::applyFloor(x2, activation_min);
588 x3 = ACTIVATION::applyFloor(x3, activation_min);
589 x0 = ACTIVATION::applyCeiling(x0, activation_max);
590 x1 = ACTIVATION::applyCeiling(x1, activation_max);
591 x2 = ACTIVATION::applyCeiling(x2, activation_max);
592 x3 = ACTIVATION::applyCeiling(x3, activation_max);
593 vst1q_f32(output_data + i, x0);
594 vst1q_f32(output_data + i + 4, x1);
595 vst1q_f32(output_data + i + 8, x2);
596 vst1q_f32(output_data + i + 12, x3);
598 for (; i <=
size - 4; i += 4)
600 auto a1 = vld1q_f32(input1_data + i);
601 auto a2 = vld1q_f32(input2_data + i);
602 auto x = OPERATOR::calculate(a1, a2);
604 ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
605 vst1q_f32(output_data + i, x_clamped);
608 for (; i <
size; i++)
610 auto x = OPERATOR::calculate(input1_data[i], input2_data[i]);
611 output_data[i] = ACTIVATION::applyCeiling(
619template <
class OPERATOR,
class ACTIVATION>
621 const float broadcast_value,
const float *input2_data,
629 const auto broadcast_value_dup = vdupq_n_f32(broadcast_value);
630 for (; i <=
size - 16; i += 16)
632 auto a20 = vld1q_f32(input2_data + i);
633 auto a21 = vld1q_f32(input2_data + i + 4);
634 auto a22 = vld1q_f32(input2_data + i + 8);
635 auto a23 = vld1q_f32(input2_data + i + 12);
636 auto x0 = OPERATOR::calculate(broadcast_value_dup, a20);
637 auto x1 = OPERATOR::calculate(broadcast_value_dup, a21);
638 auto x2 = OPERATOR::calculate(broadcast_value_dup, a22);
639 auto x3 = OPERATOR::calculate(broadcast_value_dup, a23);
640 x0 = ACTIVATION::applyFloor(x0, activation_min);
641 x1 = ACTIVATION::applyFloor(x1, activation_min);
642 x2 = ACTIVATION::applyFloor(x2, activation_min);
643 x3 = ACTIVATION::applyFloor(x3, activation_min);
644 x0 = ACTIVATION::applyCeiling(x0, activation_max);
645 x1 = ACTIVATION::applyCeiling(x1, activation_max);
646 x2 = ACTIVATION::applyCeiling(x2, activation_max);
647 x3 = ACTIVATION::applyCeiling(x3, activation_max);
648 vst1q_f32(output_data + i, x0);
649 vst1q_f32(output_data + i + 4, x1);
650 vst1q_f32(output_data + i + 8, x2);
651 vst1q_f32(output_data + i + 12, x3);
653 for (; i <=
size - 4; i += 4)
655 auto a2 = vld1q_f32(input2_data + i);
656 auto x = OPERATOR::calculate(broadcast_value_dup, a2);
658 ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
659 vst1q_f32(output_data + i, x_clamped);
662 for (; i <
size; i++)
664 auto x = OPERATOR::calculate(broadcast_value, input2_data[i]);
665 output_data[i] = ACTIVATION::applyCeiling(
681 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatNone>);
684 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMax>);
687 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMinMax>);
691inline typename std::enable_if_t<is_quant8<T>::value>
696 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
700 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
704 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
705 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
712 uint8_t broadcast_value,
const uint8_t *input2_data,
713 uint8_t *output_data)
716 int32_t clamped_output;
717 for (; i <
size; ++i)
719 clamped_output =
quant8_sum(params, broadcast_value, input2_data[i]);
720 output_data[i] =
static_cast<uint8_t
>(clamped_output);
728 const int8_t *input2_data, int8_t *output_data)
730 using gemmlowp::RoundingDivideByPOT;
733 const int32x4_t left_shift_dup = vdupq_n_s32(params.
left_shift);
738 const int8x8_t input1_val_original = vdup_n_s8(input1_data);
739 const int16x8_t input1_val_s16 = vmovl_s8(input1_val_original);
740 const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.
input1_offset));
741 const int16x4_t input1_val_high = vget_high_s16(input1_val);
742 const int16x4_t input1_val_low = vget_low_s16(input1_val);
743 int32x4_t x11 = vmovl_s16(input1_val_low);
744 int32x4_t x12 = vmovl_s16(input1_val_high);
745 x11 = vshlq_s32(x11, left_shift_dup);
746 x12 = vshlq_s32(x12, left_shift_dup);
749 const int32x4_t input1_shift_dup = vdupq_n_s32(params.
input1_shift);
750 x11 = vshlq_s32(x11, input1_shift_dup);
751 x12 = vshlq_s32(x12, input1_shift_dup);
753 for (; i <=
size - 8; i += 8)
755 const int8x8_t input2_val_original = vld1_s8(input2_data + i);
756 const int16x8_t input2_val_s16 = vmovl_s8(input2_val_original);
757 const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.
input2_offset));
758 const int16x4_t input2_val_high = vget_high_s16(input2_val);
759 const int16x4_t input2_val_low = vget_low_s16(input2_val);
760 int32x4_t x21 = vmovl_s16(input2_val_low);
761 int32x4_t x22 = vmovl_s16(input2_val_high);
762 x21 = vshlq_s32(x21, left_shift_dup);
763 x22 = vshlq_s32(x22, left_shift_dup);
766 const int32x4_t input2_shift_dup = vdupq_n_s32(params.
input2_shift);
767 x21 = vshlq_s32(x21, input2_shift_dup);
768 x22 = vshlq_s32(x22, input2_shift_dup);
769 int32x4_t s1 = vaddq_s32(x11, x21);
770 int32x4_t s2 = vaddq_s32(x12, x22);
775 const int16x4_t s1_narrowed = vmovn_s32(s1);
776 const int16x4_t s2_narrowed = vmovn_s32(s2);
778 vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.
output_offset));
779 const int8x8_t clamped =
780 vmax_s8(output_activation_min_vector, vmin_s8(output_activation_max_vector, vqmovn_s16(s)));
781 vst1_s8(output_data + i, clamped);
788 const int32_t input1_val = params.
input1_offset + input1_data;
789 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
793 for (; i <
size; ++i)
795 const int32_t input2_val = params.
input2_offset + input2_data[i];
796 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
799 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
803 const int32_t clamped_output = std::min(
805 output_data[i] =
static_cast<int8_t
>(clamped_output);
811inline typename std::enable_if_t<is_quant8<T>::value>
813 const T *input1_data,
const Shape &input2_shape,
const T *input2_data,
820 return static_cast<T
>(
quant8_sum(params, a, b));
828 params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data,
836 const float *input1_data,
const Shape &input2_shape,
842 const std::function<float(
const float &,
const float &)> fn =
843 [](
const float &a,
const float &b) ->
float {
return a + b; };
849 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
853 input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data,
854 implFuncs.first, implFuncs.second);
859 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
863 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
864 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
868 const float *input1_data,
const Shape &input2_shape,
874 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
876 output_shape, output_data, implFuncs.first, implFuncs.second);
881 getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncSubFloat>>(params);
883 output_shape, output_data, implFuncs.first, implFuncs.second);
887 const std::function<float(
const float &,
const float &)> fn =
888 [](
const float &a,
const float &b) ->
float {
return a - b; };
895inline typename std::enable_if_t<is_quant8<T>::value, int32_t>
898 const int32_t input1_val = params.
input1_offset + input1_data;
899 const int32_t input2_val = params.
input2_offset + input2_data;
900 const int32_t unclamped_result =
904 const int32_t clamped_output = std::min(
907 return clamped_output;
911 const uint8_t *input1_data,
const uint8_t *input2_data,
912 uint8_t *output_data)
917 const auto input1_offset_vector = vdupq_n_s16(params.
input1_offset);
918 const auto input2_offset_vector = vdupq_n_s16(params.
input2_offset);
919 const auto output_offset_vector = vdupq_n_s16(params.
output_offset);
922 const int left_shift = std::max(0, params.
output_shift);
923 const int right_shift = std::max(0, -params.
output_shift);
924 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
925 for (; i <=
size - 8; i += 8)
928 const auto input1_val_original = vld1_u8(input1_data + i);
929 const auto input2_val_original = vld1_u8(input2_data + i);
930 const auto input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
931 const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
932 const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
933 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
935 const auto input1_val_low = vget_low_s16(input1_val);
936 const auto input1_val_high = vget_high_s16(input1_val);
937 const auto input2_val_low = vget_low_s16(input2_val);
938 const auto input2_val_high = vget_high_s16(input2_val);
940 auto p1 = vmull_s16(input2_val_low, input1_val_low);
941 auto p2 = vmull_s16(input2_val_high, input1_val_high);
943 p1 = vshlq_s32(p1, left_shift_vec);
944 p2 = vshlq_s32(p2, left_shift_vec);
947 using gemmlowp::RoundingDivideByPOT;
948 p1 = RoundingDivideByPOT(p1, right_shift);
949 p2 = RoundingDivideByPOT(p2, right_shift);
951 const auto p1_narrowed = vqmovn_s32(p1);
952 const auto p2_narrowed = vqmovn_s32(p2);
953 const auto p = vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
955 vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
956 vst1_u8(output_data + i, clamped);
960 for (; i <
size; ++i)
962 const int32_t input1_val = params.
input1_offset + input1_data[i];
963 const int32_t input2_val = params.
input2_offset + input2_data[i];
964 const int32_t unclamped_result =
968 const int32_t clamped_output = std::min(
970 output_data[i] =
static_cast<uint8_t
>(clamped_output);
975 const int8_t *input1_data,
const int8_t *input2_data,
980 const int16x8_t input1_offset_vector = vdupq_n_s16(params.
input1_offset);
981 const int16x8_t input2_offset_vector = vdupq_n_s16(params.
input2_offset);
982 const int16x8_t output_offset_vector = vdupq_n_s16(params.
output_offset);
985 const int left_shift = std::max(0, params.
output_shift);
986 const int right_shift = std::max(0, -params.
output_shift);
987 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
988 for (; i <=
size - 16; i += 16)
991 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
992 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
994 const int16x8_t input1_val_s16_high = vmovl_s8(vget_high_s8(input1_val_original));
995 const int16x8_t input1_val_s16_low = vmovl_s8(vget_low_s8(input1_val_original));
997 const int16x8_t input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
998 const int16x8_t input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
999 const int16x8_t input1_val_high = vaddq_s16(input1_val_s16_high, input1_offset_vector);
1000 const int16x8_t input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_vector);
1001 const int16x8_t input1_val_low = vaddq_s16(input1_val_s16_low, input1_offset_vector);
1002 const int16x8_t input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_vector);
1003 const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
1004 const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
1005 const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
1006 const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
1007 const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
1008 const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
1009 const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
1010 const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
1012 auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high);
1013 auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low);
1014 auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high);
1015 auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low);
1017 p1 = vshlq_s32(p1, left_shift_vec);
1018 p2 = vshlq_s32(p2, left_shift_vec);
1019 p3 = vshlq_s32(p3, left_shift_vec);
1020 p4 = vshlq_s32(p4, left_shift_vec);
1026 using gemmlowp::RoundingDivideByPOT;
1027 p1 = RoundingDivideByPOT(p1, right_shift);
1028 p2 = RoundingDivideByPOT(p2, right_shift);
1029 p3 = RoundingDivideByPOT(p3, right_shift);
1030 p4 = RoundingDivideByPOT(p4, right_shift);
1032 const auto p1_narrowed = vqmovn_s32(p1);
1033 const auto p2_narrowed = vqmovn_s32(p2);
1034 const auto p3_narrowed = vqmovn_s32(p3);
1035 const auto p4_narrowed = vqmovn_s32(p4);
1037 const int16x8_t p_part1 =
1038 vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
1039 const int16x8_t p_part2 =
1040 vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
1041 const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));
1043 const auto clamped =
1044 vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector, p));
1045 vst1q_s8(output_data + i, clamped);
1049 for (; i <
size; ++i)
1051 const int32_t input1_val = params.
input1_offset + input1_data[i];
1052 const int32_t input2_val = params.
input2_offset + input2_data[i];
1053 const int32_t unclamped_result =
1057 const int32_t clamped_output = std::min(
1059 output_data[i] =
static_cast<int8_t
>(clamped_output);
1063template <
typename T>
1064inline typename std::enable_if_t<is_quant8<T>::value>
1069 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
1073 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
1077 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
1078 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
1082 const uint8_t broadcast_value,
const uint8_t *input2_data,
1083 uint8_t *output_data)
1086 int32_t clamped_output;
1087 for (; i <
size; ++i)
1089 clamped_output =
quant8_mul(params, broadcast_value, input2_data[i]);
1090 output_data[i] =
static_cast<uint8_t
>(clamped_output);
1096 const int8_t broadcast_value,
const int8_t *input2_data,
1097 int8_t *output_data)
1099 const int16_t input1_val = params.
input1_offset + broadcast_value;
1103 const auto input2_offset_vector = vdupq_n_s16(params.
input2_offset);
1104 const auto output_offset_vector = vdupq_n_s16(params.
output_offset);
1107 const int left_shift = std::max(0, params.
output_shift);
1108 const int right_shift = std::max(0, -params.
output_shift);
1109 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
1110 for (; i <=
size - 16; i += 16)
1113 const auto input2_val_original = vld1q_s8(input2_data + i);
1114 const auto input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
1115 const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
1117 const auto input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_vector);
1118 const auto input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_vector);
1120 const auto input2_val_low_low = vget_low_s16(input2_val_low);
1121 const auto input2_val_low_high = vget_high_s16(input2_val_low);
1122 const auto input2_val_high_low = vget_low_s16(input2_val_high);
1123 const auto input2_val_high_high = vget_high_s16(input2_val_high);
1125 auto p1 = vmull_n_s16(input2_val_high_high, input1_val);
1126 auto p2 = vmull_n_s16(input2_val_high_low, input1_val);
1127 auto p3 = vmull_n_s16(input2_val_low_high, input1_val);
1128 auto p4 = vmull_n_s16(input2_val_low_low, input1_val);
1130 p1 = vshlq_s32(p1, left_shift_vec);
1131 p2 = vshlq_s32(p2, left_shift_vec);
1132 p3 = vshlq_s32(p3, left_shift_vec);
1133 p4 = vshlq_s32(p4, left_shift_vec);
1139 using gemmlowp::RoundingDivideByPOT;
1140 p1 = RoundingDivideByPOT(p1, right_shift);
1141 p2 = RoundingDivideByPOT(p2, right_shift);
1142 p3 = RoundingDivideByPOT(p3, right_shift);
1143 p4 = RoundingDivideByPOT(p4, right_shift);
1145 const auto p1_narrowed = vqmovn_s32(p1);
1146 const auto p2_narrowed = vqmovn_s32(p2);
1147 const auto p3_narrowed = vqmovn_s32(p3);
1148 const auto p4_narrowed = vqmovn_s32(p4);
1150 const int16x8_t p_part1 =
1151 vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
1152 const int16x8_t p_part2 =
1153 vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
1154 const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));
1156 const auto clamped =
1157 vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector, p));
1158 vst1q_s8(output_data + i, clamped);
1162 for (; i <
size; ++i)
1164 const int32_t input2_val = params.
input2_offset + input2_data[i];
1165 const int32_t unclamped_result =
1169 const int32_t clamped_output = std::min(
1171 output_data[i] =
static_cast<int8_t
>(clamped_output);
1175template <
typename T>
1176inline typename std::enable_if_t<is_quant8<T>::value>
1178 const T *input1_data,
const Shape &input2_shape,
const T *input2_data,
1185 return static_cast<T
>(
quant8_mul(params, a, b));
1192 params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data,
1200 const float *input1_data,
const Shape &input2_shape,
1207 const std::function<float(
const float &,
const float &)> fn =
1208 [](
const float &a,
const float &b) ->
float {
return a * b; };
1213 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
1215 output_shape, output_data, implFuncs.first, implFuncs.second);
1219 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
1224 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
1225 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
1227 const std::function<float(
const float &,
const float &)> fn =
1228 [](
const float &a,
const float &b) ->
float {
return a / b; };
1235 const float *input1_data,
const Shape &input2_shape,
1242 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
1244 output_shape, output_data, implFuncs.first, implFuncs.second);
1249 getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncDivFloat>>(params);
1251 output_shape, output_data, implFuncs.first, implFuncs.second);
1256 const std::function<float(
const float &,
const float &)> fn =
1257 [](
const float &a,
const float &b) ->
float {
return a / b; };
const luci_interpreter::RuntimeShape output_shape
void BroadcastDivDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
std::enable_if_t< is_quant8< T >::value > Mul(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam ¶ms, const uint8_t broadcast_value, const uint8_t *input2_data, uint8_t *output_data)
std::enable_if_t< is_quant8< T >::value, int32_t > quant8_mul(const BinaryArithmeticOpParam ¶ms, const T input1_data, const T input2_data)
std::enable_if_t< is_quant8< T >::value, int32_t > quant8_sum(const BinaryArithmeticOpParam ¶ms, const T input1_data, const T input2_data)
void AddElementwise(int size, const BinaryArithmeticOpParam ¶ms, const uint8_t *input1_data, const uint8_t *input2_data, uint8_t *output_data)
void BinaryOpScalarBroadcast(int size, const BinaryArithmeticOpParam ¶ms, const float broadcast_value, const float *input2_data, float *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastAddDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void Sub(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
void MulElementwise(int size, const BinaryArithmeticOpParam ¶ms, const uint8_t *input1_data, const uint8_t *input2_data, uint8_t *output_data)
void AddScalarBroadcast(int size, const BinaryArithmeticOpParam ¶ms, uint8_t broadcast_value, const uint8_t *input2_data, uint8_t *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastMulDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void BroadcastSubDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
BinaryOpImplFloatFuncs getBinaryOpWithActivationImplFloat(const BinaryArithmeticOpParam ¶ms)
void BinaryOpElementwise(int size, const BinaryArithmeticOpParam ¶ms, const float *input1_data, const float *input2_data, float *output_data)
void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam ¶ms, bool switch_inputs, const Shape &, const T *unswitched_input1_data, const Shape &, const T *unswitched_input2_data, const Shape &, T *output_data, ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
void Div(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
std::pair< void(*)(int, const BinaryArithmeticOpParam &, const float *, const float *, float *), void(*)(int, const BinaryArithmeticOpParam &, const float, const float *, float *)> BinaryOpImplFloatFuncs
std::enable_if_t< is_quant8< T >::value > Add(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, const std::function< T(const BinaryArithmeticOpParam ¶ms, const T &, const T &)> &fn)
void BinaryArithmeticOp(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, const std::function< T(const T &, const T &)> &fn)
@ kSecondInputBroadcastsFast
@ kFirstInputBroadcastsFast
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
int32_t quantized_activation_max
BroadcastableOpCategory broadcast_category
int32_t input2_multiplier
int32_t quantized_activation_min
float float_activation_max
int32_t output_multiplier
int32_t input1_multiplier
float float_activation_min
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static T calculate(const T &a, const T &b)