18#ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19#define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
29#include "fixedpoint/fixedpoint.h"
39template <
typename ElementwiseF,
typename ScalarBroadcastF,
typename T>
42 const T *unswitched_input1_data,
44 const T *unswitched_input2_data,
45 const Shape & , T *output_data,
46 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
48 const T *input1_data = switch_inputs ? unswitched_input2_data : unswitched_input1_data;
49 const T *input2_data = switch_inputs ? unswitched_input1_data : unswitched_input2_data;
55 T *output_data_ptr = output_data;
56 const T *input1_data_ptr = input1_data;
57 const T *input2_data_reset = input2_data;
73 for (
int i0 = 0; i0 < y0; ++i0)
75 const T *input2_data_ptr =
nullptr;
76 for (
int i1 = 0; i1 < y1; ++i1)
78 input2_data_ptr = input2_data_reset;
79 for (
int i2 = 0; i2 < y2; ++i2)
81 for (
int i3 = 0; i3 < y3; ++i3)
83 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
84 input2_data_ptr += y4;
85 output_data_ptr += y4;
88 input1_data_ptr += y4;
92 input2_data_reset = input2_data_ptr;
107 for (
int i0 = 0; i0 < y0; ++i0)
109 const T *input2_data_ptr =
nullptr;
110 for (
int i1 = 0; i1 < y1; ++i1)
112 input2_data_ptr = input2_data_reset;
113 for (
int i2 = 0; i2 < y2; ++i2)
115 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
116 input2_data_ptr += y3;
117 output_data_ptr += y3;
118 input1_data_ptr += 1;
121 input2_data_reset = input2_data_ptr;
127template <
typename ElementwiseF,
typename ScalarBroadcastF,
typename T>
130 const T *unswitched_input1_data,
132 const T *unswitched_input2_data,
133 const Shape & , T *output_data,
134 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
144 const bool use_unswitched =
148 const T *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
149 const T *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
155 T *output_data_ptr = output_data;
156 const T *input1_data_ptr = input1_data;
157 const T *input2_data_reset = input2_data;
172 for (
int i0 = 0; i0 < y0; ++i0)
174 const T *input2_data_ptr =
nullptr;
175 for (
int i1 = 0; i1 < y1; ++i1)
177 input2_data_ptr = input2_data_reset;
178 for (
int i2 = 0; i2 < y2; ++i2)
180 for (
int i3 = 0; i3 < y3; ++i3)
182 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
183 input2_data_ptr += y4;
184 output_data_ptr += y4;
187 input1_data_ptr += y4;
191 input2_data_reset = input2_data_ptr;
206 for (
int i0 = 0; i0 < y0; ++i0)
208 const T *input2_data_ptr =
nullptr;
209 for (
int i1 = 0; i1 < y1; ++i1)
211 input2_data_ptr = input2_data_reset;
212 for (
int i2 = 0; i2 < y2; ++i2)
214 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
215 input2_data_ptr += y3;
216 output_data_ptr += y3;
217 input1_data_ptr += 1;
220 input2_data_reset = input2_data_ptr;
226inline typename std::enable_if_t<is_quant8<T>::value, int32_t>
229 const int32_t input1_val = params.
input1_offset + input1_data;
230 const int32_t input2_val = params.
input2_offset + input2_data;
231 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
232 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
237 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
243 return clamped_output;
247 const uint8_t *input1_data,
const uint8_t *input2_data,
248 uint8_t *output_data)
255 for (; i <=
size - 8; i += 8)
257 const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
258 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
259 const int16x8_t input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
260 const int16x8_t input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
261 const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.
input1_offset));
262 const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.
input2_offset));
263 const int16x4_t input1_val_high = vget_high_s16(input1_val);
264 const int16x4_t input1_val_low = vget_low_s16(input1_val);
265 const int16x4_t input2_val_high = vget_high_s16(input2_val);
266 const int16x4_t input2_val_low = vget_low_s16(input2_val);
267 int32x4_t x11 = vmovl_s16(input1_val_low);
268 int32x4_t x12 = vmovl_s16(input1_val_high);
269 int32x4_t x21 = vmovl_s16(input2_val_low);
270 int32x4_t x22 = vmovl_s16(input2_val_high);
271 const int32x4_t left_shift_dup = vdupq_n_s32(params.
left_shift);
272 x11 = vshlq_s32(x11, left_shift_dup);
273 x12 = vshlq_s32(x12, left_shift_dup);
274 x21 = vshlq_s32(x21, left_shift_dup);
275 x22 = vshlq_s32(x22, left_shift_dup);
280 const int32x4_t input1_shift_dup = vdupq_n_s32(params.
input1_shift);
281 const int32x4_t input2_shift_dup = vdupq_n_s32(params.
input2_shift);
282 x11 = vshlq_s32(x11, input1_shift_dup);
283 x12 = vshlq_s32(x12, input1_shift_dup);
284 x21 = vshlq_s32(x21, input2_shift_dup);
285 x22 = vshlq_s32(x22, input2_shift_dup);
286 int32x4_t s1 = vaddq_s32(x11, x21);
287 int32x4_t s2 = vaddq_s32(x12, x22);
290 using gemmlowp::RoundingDivideByPOT;
293 const int16x4_t s1_narrowed = vmovn_s32(s1);
294 const int16x4_t s2_narrowed = vmovn_s32(s2);
296 vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.
output_offset));
297 const uint8x8_t clamped =
298 vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
299 vst1_u8(output_data + i, clamped);
302 for (; i <
size; ++i)
304 const int32_t input1_val = params.
input1_offset + input1_data[i];
305 const int32_t input2_val = params.
input2_offset + input2_data[i];
306 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
307 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
312 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
318 output_data[i] =
static_cast<uint8_t
>(clamped_output);
323 const int8_t *input1_data,
const int8_t *input2_data,
333 const int32x4_t input1_left_dup = vdupq_n_s32(input1_left_shift);
334 const int32x4_t input2_left_dup = vdupq_n_s32(input2_left_shift);
336 const int16x8_t input1_offset_dup = vdupq_n_s16(params.
input1_offset);
337 const int16x8_t input2_offset_dup = vdupq_n_s16(params.
input2_offset);
339 for (; i <=
size - 16; i += 16)
341 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
342 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
344 const int16x8_t input1_val_s16_high = vmovl_s8(vget_high_s8(input1_val_original));
345 const int16x8_t input1_val_s16_low = vmovl_s8(vget_low_s8(input1_val_original));
347 const int16x8_t input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
348 const int16x8_t input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
349 const int16x8_t input1_val_high = vaddq_s16(input1_val_s16_high, input1_offset_dup);
350 const int16x8_t input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_dup);
351 const int16x8_t input1_val_low = vaddq_s16(input1_val_s16_low, input1_offset_dup);
352 const int16x8_t input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_dup);
353 const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
354 const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
355 const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
356 const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
357 const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
358 const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
359 const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
360 const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
361 int32x4_t x111 = vmovl_s16(input1_val_low_low);
362 int32x4_t x112 = vmovl_s16(input1_val_low_high);
363 int32x4_t x121 = vmovl_s16(input1_val_high_low);
364 int32x4_t x122 = vmovl_s16(input1_val_high_high);
365 int32x4_t x211 = vmovl_s16(input2_val_low_low);
366 int32x4_t x212 = vmovl_s16(input2_val_low_high);
367 int32x4_t x221 = vmovl_s16(input2_val_high_low);
368 int32x4_t x222 = vmovl_s16(input2_val_high_high);
370 x111 = vshlq_s32(x111, input1_left_dup);
371 x112 = vshlq_s32(x112, input1_left_dup);
372 x121 = vshlq_s32(x121, input1_left_dup);
373 x122 = vshlq_s32(x122, input1_left_dup);
374 x211 = vshlq_s32(x211, input2_left_dup);
375 x212 = vshlq_s32(x212, input2_left_dup);
376 x221 = vshlq_s32(x221, input2_left_dup);
377 x222 = vshlq_s32(x222, input2_left_dup);
386 int32x4_t s11 = vaddq_s32(x111, x211);
387 int32x4_t s12 = vaddq_s32(x112, x212);
388 int32x4_t s21 = vaddq_s32(x121, x221);
389 int32x4_t s22 = vaddq_s32(x122, x222);
394 using gemmlowp::RoundingDivideByPOT;
399 const int16x4_t s11_narrowed = vmovn_s32(s11);
400 const int16x4_t s12_narrowed = vmovn_s32(s12);
401 const int16x4_t s21_narrowed = vmovn_s32(s21);
402 const int16x4_t s22_narrowed = vmovn_s32(s22);
404 vaddq_s16(vcombine_s16(s11_narrowed, s12_narrowed), vdupq_n_s16(params.
output_offset));
406 vaddq_s16(vcombine_s16(s21_narrowed, s22_narrowed), vdupq_n_s16(params.
output_offset));
407 const int8x16_t s = vcombine_s8(vqmovn_s16(s1), vqmovn_s16(s2));
409 const int8x16_t clamped =
410 vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector, s));
411 vst1q_s8(output_data + i, clamped);
415 for (; i <
size; ++i)
417 const int32_t input1_val = params.
input1_offset + input1_data[i];
418 const int32_t input2_val = params.
input2_offset + input2_data[i];
419 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
420 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
425 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
431 output_data[i] =
static_cast<int8_t
>(clamped_output);
438 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
440 return vaddq_f32(a, b);
443 static inline float calculate(
const float a,
const float b) {
return a + b; }
449 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
451 return vsubq_f32(a, b);
454 static inline float calculate(
const float a,
const float b) {
return a - b; }
460 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
462 return vmulq_f32(a, b);
465 static inline float calculate(
const float a,
const float b) {
return a * b; }
472 static inline float32x4_t
calculate(
const float32x4_t &a,
const float32x4_t &b)
474 return vdivq_f32(a, b);
478 static inline float calculate(
const float a,
const float b) {
return a / b; }
483 template <
typename T>
static inline T
calculate(
const T &a,
const T &b)
485 return BASEOPERATOR::calculate(b, a);
483 template <
typename T>
static inline T
calculate(
const T &a,
const T &b) {
…}
492 static inline float32x4_t
applyCeiling(
const float32x4_t &value,
const float32x4_t &ceilingParam)
497 static inline float32x4_t
applyFloor(
const float32x4_t &value,
const float32x4_t &floorParam)
503 static inline float applyCeiling(
const float value,
const float ceilingParam)
503 static inline float applyCeiling(
const float value,
const float ceilingParam) {
…}
508 static inline float applyFloor(
const float value,
const float floorParam)
508 static inline float applyFloor(
const float value,
const float floorParam) {
…}
518 static inline float32x4_t
applyCeiling(
const float32x4_t &value,
const float32x4_t &ceilingParam)
523 static inline float32x4_t
applyFloor(
const float32x4_t &value,
const float32x4_t &floorParam)
525 return vmaxq_f32(value, floorParam);
528 static inline float applyCeiling(
const float value,
const float ceilingParam)
528 static inline float applyCeiling(
const float value,
const float ceilingParam) {
…}
533 static inline float applyFloor(
const float value,
const float floorParam)
535 return std::max(value, floorParam);
533 static inline float applyFloor(
const float value,
const float floorParam) {
…}
542 static inline float32x4_t
applyCeiling(
const float32x4_t &value,
const float32x4_t &ceilingParam)
544 return vminq_f32(value, ceilingParam);
546 static inline float32x4_t
applyFloor(
const float32x4_t &value,
const float32x4_t &floorParam)
548 return vmaxq_f32(value, floorParam);
551 static inline float applyCeiling(
const float value,
const float ceilingParam)
553 return std::min(value, ceilingParam);
551 static inline float applyCeiling(
const float value,
const float ceilingParam) {
…}
555 static inline float applyFloor(
const float value,
const float floorParam)
557 return std::max(value, floorParam);
555 static inline float applyFloor(
const float value,
const float floorParam) {
…}
561template <
class OPERATOR,
class ACTIVATION>
563 const float *input1_data,
const float *input2_data,
571 for (; i <=
size - 16; i += 16)
573 auto a10 = vld1q_f32(input1_data + i);
574 auto a11 = vld1q_f32(input1_data + i + 4);
575 auto a12 = vld1q_f32(input1_data + i + 8);
576 auto a13 = vld1q_f32(input1_data + i + 12);
577 auto a20 = vld1q_f32(input2_data + i);
578 auto a21 = vld1q_f32(input2_data + i + 4);
579 auto a22 = vld1q_f32(input2_data + i + 8);
580 auto a23 = vld1q_f32(input2_data + i + 12);
581 auto x0 = OPERATOR::calculate(a10, a20);
582 auto x1 = OPERATOR::calculate(a11, a21);
583 auto x2 = OPERATOR::calculate(a12, a22);
584 auto x3 = OPERATOR::calculate(a13, a23);
585 x0 = ACTIVATION::applyFloor(x0, activation_min);
586 x1 = ACTIVATION::applyFloor(x1, activation_min);
587 x2 = ACTIVATION::applyFloor(x2, activation_min);
588 x3 = ACTIVATION::applyFloor(x3, activation_min);
589 x0 = ACTIVATION::applyCeiling(x0, activation_max);
590 x1 = ACTIVATION::applyCeiling(x1, activation_max);
591 x2 = ACTIVATION::applyCeiling(x2, activation_max);
592 x3 = ACTIVATION::applyCeiling(x3, activation_max);
593 vst1q_f32(output_data + i, x0);
594 vst1q_f32(output_data + i + 4, x1);
595 vst1q_f32(output_data + i + 8, x2);
596 vst1q_f32(output_data + i + 12, x3);
598 for (; i <=
size - 4; i += 4)
600 auto a1 = vld1q_f32(input1_data + i);
601 auto a2 = vld1q_f32(input2_data + i);
602 auto x = OPERATOR::calculate(a1, a2);
604 ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
605 vst1q_f32(output_data + i, x_clamped);
608 for (; i <
size; i++)
610 auto x = OPERATOR::calculate(input1_data[i], input2_data[i]);
611 output_data[i] = ACTIVATION::applyCeiling(
619template <
class OPERATOR,
class ACTIVATION>
621 const float broadcast_value,
const float *input2_data,
629 const auto broadcast_value_dup = vdupq_n_f32(broadcast_value);
630 for (; i <=
size - 16; i += 16)
632 auto a20 = vld1q_f32(input2_data + i);
633 auto a21 = vld1q_f32(input2_data + i + 4);
634 auto a22 = vld1q_f32(input2_data + i + 8);
635 auto a23 = vld1q_f32(input2_data + i + 12);
636 auto x0 = OPERATOR::calculate(broadcast_value_dup, a20);
637 auto x1 = OPERATOR::calculate(broadcast_value_dup, a21);
638 auto x2 = OPERATOR::calculate(broadcast_value_dup, a22);
639 auto x3 = OPERATOR::calculate(broadcast_value_dup, a23);
640 x0 = ACTIVATION::applyFloor(x0, activation_min);
641 x1 = ACTIVATION::applyFloor(x1, activation_min);
642 x2 = ACTIVATION::applyFloor(x2, activation_min);
643 x3 = ACTIVATION::applyFloor(x3, activation_min);
644 x0 = ACTIVATION::applyCeiling(x0, activation_max);
645 x1 = ACTIVATION::applyCeiling(x1, activation_max);
646 x2 = ACTIVATION::applyCeiling(x2, activation_max);
647 x3 = ACTIVATION::applyCeiling(x3, activation_max);
648 vst1q_f32(output_data + i, x0);
649 vst1q_f32(output_data + i + 4, x1);
650 vst1q_f32(output_data + i + 8, x2);
651 vst1q_f32(output_data + i + 12, x3);
653 for (; i <=
size - 4; i += 4)
655 auto a2 = vld1q_f32(input2_data + i);
656 auto x = OPERATOR::calculate(broadcast_value_dup, a2);
658 ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
659 vst1q_f32(output_data + i, x_clamped);
662 for (; i <
size; i++)
664 auto x = OPERATOR::calculate(broadcast_value, input2_data[i]);
665 output_data[i] = ACTIVATION::applyCeiling(
681 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatNone>);
684 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMax>);
687 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMinMax>);
691inline typename std::enable_if_t<is_quant8<T>::value>
696 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
700 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
704 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
705 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
712 uint8_t broadcast_value,
const uint8_t *input2_data,
713 uint8_t *output_data)
716 int32_t clamped_output;
717 for (; i <
size; ++i)
719 clamped_output =
quant8_sum(params, broadcast_value, input2_data[i]);
720 output_data[i] =
static_cast<uint8_t
>(clamped_output);
728 const int8_t *input2_data, int8_t *output_data)
730 using gemmlowp::RoundingDivideByPOT;
733 const int32x4_t left_shift_dup = vdupq_n_s32(params.
left_shift);
738 const int8x8_t input1_val_original = vdup_n_s8(input1_data);
739 const int16x8_t input1_val_s16 = vmovl_s8(input1_val_original);
740 const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.
input1_offset));
741 const int16x4_t input1_val_high = vget_high_s16(input1_val);
742 const int16x4_t input1_val_low = vget_low_s16(input1_val);
743 int32x4_t x11 = vmovl_s16(input1_val_low);
744 int32x4_t x12 = vmovl_s16(input1_val_high);
745 x11 = vshlq_s32(x11, left_shift_dup);
746 x12 = vshlq_s32(x12, left_shift_dup);
749 const int32x4_t input1_shift_dup = vdupq_n_s32(params.
input1_shift);
750 x11 = vshlq_s32(x11, input1_shift_dup);
751 x12 = vshlq_s32(x12, input1_shift_dup);
753 for (; i <=
size - 8; i += 8)
755 const int8x8_t input2_val_original = vld1_s8(input2_data + i);
756 const int16x8_t input2_val_s16 = vmovl_s8(input2_val_original);
757 const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.
input2_offset));
758 const int16x4_t input2_val_high = vget_high_s16(input2_val);
759 const int16x4_t input2_val_low = vget_low_s16(input2_val);
760 int32x4_t x21 = vmovl_s16(input2_val_low);
761 int32x4_t x22 = vmovl_s16(input2_val_high);
762 x21 = vshlq_s32(x21, left_shift_dup);
763 x22 = vshlq_s32(x22, left_shift_dup);
766 const int32x4_t input2_shift_dup = vdupq_n_s32(params.
input2_shift);
767 x21 = vshlq_s32(x21, input2_shift_dup);
768 x22 = vshlq_s32(x22, input2_shift_dup);
769 int32x4_t s1 = vaddq_s32(x11, x21);
770 int32x4_t s2 = vaddq_s32(x12, x22);
775 const int16x4_t s1_narrowed = vmovn_s32(s1);
776 const int16x4_t s2_narrowed = vmovn_s32(s2);
778 vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.
output_offset));
779 const int8x8_t clamped =
780 vmax_s8(output_activation_min_vector, vmin_s8(output_activation_max_vector, vqmovn_s16(s)));
781 vst1_s8(output_data + i, clamped);
788 const int32_t input1_val = params.
input1_offset + input1_data;
789 const int32_t shifted_input1_val = input1_val * (1 << params.
left_shift);
793 for (; i <
size; ++i)
795 const int32_t input2_val = params.
input2_offset + input2_data[i];
796 const int32_t shifted_input2_val = input2_val * (1 << params.
left_shift);
799 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
803 const int32_t clamped_output = std::min(
805 output_data[i] =
static_cast<int8_t
>(clamped_output);
811inline typename std::enable_if_t<is_quant8<T>::value>
813 const T *input1_data,
const Shape &input2_shape,
const T *input2_data,
820 return static_cast<T
>(
quant8_sum(params, a, b));
828 params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data,
836 const float *input1_data,
const Shape &input2_shape,
842 const std::function<float(
const float &,
const float &)> fn =
843 [](
const float &a,
const float &b) ->
float {
return a + b; };
849 auto [implFunc1, implFunc2] = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
852 input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data, implFunc1,
858 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
862 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
863 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
867 const float *input1_data,
const Shape &input2_shape,
873 auto [implFunc1, implFunc2] = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
879 auto [implFunc1, implFunc2] =
880 getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncSubFloat>>(params);
886 const std::function<float(
const float &,
const float &)> fn =
887 [](
const float &a,
const float &b) ->
float {
return a - b; };
894inline typename std::enable_if_t<is_quant8<T>::value, int32_t>
897 const int32_t input1_val = params.
input1_offset + input1_data;
898 const int32_t input2_val = params.
input2_offset + input2_data;
899 const int32_t unclamped_result =
903 const int32_t clamped_output = std::min(
906 return clamped_output;
910 const uint8_t *input1_data,
const uint8_t *input2_data,
911 uint8_t *output_data)
916 const auto input1_offset_vector = vdupq_n_s16(params.
input1_offset);
917 const auto input2_offset_vector = vdupq_n_s16(params.
input2_offset);
918 const auto output_offset_vector = vdupq_n_s16(params.
output_offset);
921 const int left_shift = std::max(0, params.
output_shift);
922 const int right_shift = std::max(0, -params.
output_shift);
923 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
924 for (; i <=
size - 8; i += 8)
927 const auto input1_val_original = vld1_u8(input1_data + i);
928 const auto input2_val_original = vld1_u8(input2_data + i);
929 const auto input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
930 const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
931 const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
932 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
934 const auto input1_val_low = vget_low_s16(input1_val);
935 const auto input1_val_high = vget_high_s16(input1_val);
936 const auto input2_val_low = vget_low_s16(input2_val);
937 const auto input2_val_high = vget_high_s16(input2_val);
939 auto p1 = vmull_s16(input2_val_low, input1_val_low);
940 auto p2 = vmull_s16(input2_val_high, input1_val_high);
942 p1 = vshlq_s32(p1, left_shift_vec);
943 p2 = vshlq_s32(p2, left_shift_vec);
946 using gemmlowp::RoundingDivideByPOT;
947 p1 = RoundingDivideByPOT(p1, right_shift);
948 p2 = RoundingDivideByPOT(p2, right_shift);
950 const auto p1_narrowed = vqmovn_s32(p1);
951 const auto p2_narrowed = vqmovn_s32(p2);
952 const auto p = vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
954 vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(
p)));
955 vst1_u8(output_data + i, clamped);
959 for (; i <
size; ++i)
961 const int32_t input1_val = params.
input1_offset + input1_data[i];
962 const int32_t input2_val = params.
input2_offset + input2_data[i];
963 const int32_t unclamped_result =
967 const int32_t clamped_output = std::min(
969 output_data[i] =
static_cast<uint8_t
>(clamped_output);
974 const int8_t *input1_data,
const int8_t *input2_data,
979 const int16x8_t input1_offset_vector = vdupq_n_s16(params.
input1_offset);
980 const int16x8_t input2_offset_vector = vdupq_n_s16(params.
input2_offset);
981 const int16x8_t output_offset_vector = vdupq_n_s16(params.
output_offset);
984 const int left_shift = std::max(0, params.
output_shift);
985 const int right_shift = std::max(0, -params.
output_shift);
986 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
987 for (; i <=
size - 16; i += 16)
990 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
991 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
993 const int16x8_t input1_val_s16_high = vmovl_s8(vget_high_s8(input1_val_original));
994 const int16x8_t input1_val_s16_low = vmovl_s8(vget_low_s8(input1_val_original));
996 const int16x8_t input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
997 const int16x8_t input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
998 const int16x8_t input1_val_high = vaddq_s16(input1_val_s16_high, input1_offset_vector);
999 const int16x8_t input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_vector);
1000 const int16x8_t input1_val_low = vaddq_s16(input1_val_s16_low, input1_offset_vector);
1001 const int16x8_t input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_vector);
1002 const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
1003 const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
1004 const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
1005 const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
1006 const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
1007 const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
1008 const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
1009 const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
1011 auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high);
1012 auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low);
1013 auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high);
1014 auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low);
1016 p1 = vshlq_s32(p1, left_shift_vec);
1017 p2 = vshlq_s32(p2, left_shift_vec);
1018 p3 = vshlq_s32(p3, left_shift_vec);
1019 p4 = vshlq_s32(p4, left_shift_vec);
1025 using gemmlowp::RoundingDivideByPOT;
1026 p1 = RoundingDivideByPOT(p1, right_shift);
1027 p2 = RoundingDivideByPOT(p2, right_shift);
1028 p3 = RoundingDivideByPOT(p3, right_shift);
1029 p4 = RoundingDivideByPOT(p4, right_shift);
1031 const auto p1_narrowed = vqmovn_s32(p1);
1032 const auto p2_narrowed = vqmovn_s32(p2);
1033 const auto p3_narrowed = vqmovn_s32(p3);
1034 const auto p4_narrowed = vqmovn_s32(p4);
1036 const int16x8_t p_part1 =
1037 vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
1038 const int16x8_t p_part2 =
1039 vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
1040 const int8x16_t
p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));
1042 const auto clamped =
1043 vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector,
p));
1044 vst1q_s8(output_data + i, clamped);
1048 for (; i <
size; ++i)
1050 const int32_t input1_val = params.
input1_offset + input1_data[i];
1051 const int32_t input2_val = params.
input2_offset + input2_data[i];
1052 const int32_t unclamped_result =
1056 const int32_t clamped_output = std::min(
1058 output_data[i] =
static_cast<int8_t
>(clamped_output);
1062template <
typename T>
1063inline typename std::enable_if_t<is_quant8<T>::value>
1068 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
1072 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
1076 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
1077 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
1081 const uint8_t broadcast_value,
const uint8_t *input2_data,
1082 uint8_t *output_data)
1085 int32_t clamped_output;
1086 for (; i <
size; ++i)
1088 clamped_output =
quant8_mul(params, broadcast_value, input2_data[i]);
1089 output_data[i] =
static_cast<uint8_t
>(clamped_output);
1095 const int8_t broadcast_value,
const int8_t *input2_data,
1096 int8_t *output_data)
1098 const int16_t input1_val = params.
input1_offset + broadcast_value;
1102 const auto input2_offset_vector = vdupq_n_s16(params.
input2_offset);
1103 const auto output_offset_vector = vdupq_n_s16(params.
output_offset);
1106 const int left_shift = std::max(0, params.
output_shift);
1107 const int right_shift = std::max(0, -params.
output_shift);
1108 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
1109 for (; i <=
size - 16; i += 16)
1112 const auto input2_val_original = vld1q_s8(input2_data + i);
1113 const auto input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
1114 const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
1116 const auto input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_vector);
1117 const auto input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_vector);
1119 const auto input2_val_low_low = vget_low_s16(input2_val_low);
1120 const auto input2_val_low_high = vget_high_s16(input2_val_low);
1121 const auto input2_val_high_low = vget_low_s16(input2_val_high);
1122 const auto input2_val_high_high = vget_high_s16(input2_val_high);
1124 auto p1 = vmull_n_s16(input2_val_high_high, input1_val);
1125 auto p2 = vmull_n_s16(input2_val_high_low, input1_val);
1126 auto p3 = vmull_n_s16(input2_val_low_high, input1_val);
1127 auto p4 = vmull_n_s16(input2_val_low_low, input1_val);
1129 p1 = vshlq_s32(p1, left_shift_vec);
1130 p2 = vshlq_s32(p2, left_shift_vec);
1131 p3 = vshlq_s32(p3, left_shift_vec);
1132 p4 = vshlq_s32(p4, left_shift_vec);
1138 using gemmlowp::RoundingDivideByPOT;
1139 p1 = RoundingDivideByPOT(p1, right_shift);
1140 p2 = RoundingDivideByPOT(p2, right_shift);
1141 p3 = RoundingDivideByPOT(p3, right_shift);
1142 p4 = RoundingDivideByPOT(p4, right_shift);
1144 const auto p1_narrowed = vqmovn_s32(p1);
1145 const auto p2_narrowed = vqmovn_s32(p2);
1146 const auto p3_narrowed = vqmovn_s32(p3);
1147 const auto p4_narrowed = vqmovn_s32(p4);
1149 const int16x8_t p_part1 =
1150 vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
1151 const int16x8_t p_part2 =
1152 vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
1153 const int8x16_t
p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));
1155 const auto clamped =
1156 vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector,
p));
1157 vst1q_s8(output_data + i, clamped);
1161 for (; i <
size; ++i)
1163 const int32_t input2_val = params.
input2_offset + input2_data[i];
1164 const int32_t unclamped_result =
1168 const int32_t clamped_output = std::min(
1170 output_data[i] =
static_cast<int8_t
>(clamped_output);
1174template <
typename T>
1175inline typename std::enable_if_t<is_quant8<T>::value>
1177 const T *input1_data,
const Shape &input2_shape,
const T *input2_data,
1184 return static_cast<T
>(
quant8_mul(params, a, b));
1191 params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data,
1199 const float *input1_data,
const Shape &input2_shape,
1206 const std::function<float(
const float &,
const float &)> fn =
1207 [](
const float &a,
const float &b) ->
float {
return a * b; };
1212 auto [implFunc1, implFunc2] = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
1218 const float *input1_data,
const Shape &input2_shape,
const float *input2_data,
1223 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
1224 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
1226 const std::function<float(
const float &,
const float &)> fn =
1227 [](
const float &a,
const float &b) ->
float {
return a / b; };
1234 const float *input1_data,
const Shape &input2_shape,
1241 auto [implFunc1, implFunc2] = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
1247 auto [implFunc1, implFunc2] =
1248 getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncDivFloat>>(params);
1255 const std::function<float(
const float &,
const float &)> fn =
1256 [](
const float &a,
const float &b) ->
float {
return a / b; };
const luci_interpreter::RuntimeShape output_shape
void BroadcastDivDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
std::enable_if_t< is_quant8< T >::value > Mul(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam &params, const uint8_t broadcast_value, const uint8_t *input2_data, uint8_t *output_data)
std::enable_if_t< is_quant8< T >::value, int32_t > quant8_mul(const BinaryArithmeticOpParam &params, const T input1_data, const T input2_data)
std::enable_if_t< is_quant8< T >::value, int32_t > quant8_sum(const BinaryArithmeticOpParam &params, const T input1_data, const T input2_data)
void AddElementwise(int size, const BinaryArithmeticOpParam &params, const uint8_t *input1_data, const uint8_t *input2_data, uint8_t *output_data)
void BinaryOpScalarBroadcast(int size, const BinaryArithmeticOpParam &params, const float broadcast_value, const float *input2_data, float *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void Sub(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
void MulElementwise(int size, const BinaryArithmeticOpParam &params, const uint8_t *input1_data, const uint8_t *input2_data, uint8_t *output_data)
void AddScalarBroadcast(int size, const BinaryArithmeticOpParam &params, uint8_t broadcast_value, const uint8_t *input2_data, uint8_t *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void BroadcastSubDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
BinaryOpImplFloatFuncs getBinaryOpWithActivationImplFloat(const BinaryArithmeticOpParam &params)
void BinaryOpElementwise(int size, const BinaryArithmeticOpParam &params, const float *input1_data, const float *input2_data, float *output_data)
void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &params, bool switch_inputs, const Shape &, const T *unswitched_input1_data, const Shape &, const T *unswitched_input2_data, const Shape &, T *output_data, ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
void Div(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
std::pair< void(*)(int, const BinaryArithmeticOpParam &, const float *, const float *, float *), void(*)(int, const BinaryArithmeticOpParam &, const float, const float *, float *)> BinaryOpImplFloatFuncs
std::enable_if_t< is_quant8< T >::value > Add(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, const std::function< T(const BinaryArithmeticOpParam &params, const T &, const T &)> &fn)
void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, const std::function< T(const T &, const T &)> &fn)
@ kSecondInputBroadcastsFast
@ kFirstInputBroadcastsFast
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
int32_t quantized_activation_max
BroadcastableOpCategory broadcast_category
int32_t input2_multiplier
int32_t quantized_activation_min
float float_activation_max
int32_t output_multiplier
int32_t input1_multiplier
float float_activation_min
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static T calculate(const T &a, const T &b)