ONE - On-device Neural Engine
Loading...
Searching...
No Matches
BinaryArithmeticOps.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19#define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
20
21#include <functional>
22#include <limits>
23#include <utility>
26#include "cker/Shape.h"
27#include "cker/Types.h"
28#include "cker/Utils.h"
29#include "fixedpoint/fixedpoint.h"
30
31namespace nnfw
32{
33namespace cker
34{
35namespace optimized
36{
37
38/* Old version: For Sub(float) and Div. */
39template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
40inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &params, bool switch_inputs,
41 const Shape & /* unswitched_input1_shape */,
42 const T *unswitched_input1_data,
43 const Shape & /* unswitched_input2_shape */,
44 const T *unswitched_input2_data,
45 const Shape & /* output_shape */, T *output_data,
46 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
47{
48 const T *input1_data = switch_inputs ? unswitched_input2_data : unswitched_input1_data;
49 const T *input2_data = switch_inputs ? unswitched_input1_data : unswitched_input2_data;
50
51 // Fivefold nested loops. The second input resets its position for each
52 // iteration of the second loop. The first input resets its position at the
53 // beginning of the fourth loop. The innermost loop is an elementwise add of
54 // sections of the arrays.
55 T *output_data_ptr = output_data;
56 const T *input1_data_ptr = input1_data;
57 const T *input2_data_reset = input2_data;
58 // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
59 // between input shapes. y3 for input 1 is always broadcast, and so the
60 // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
61 // Put another way,
62 // input1.shape.FlatSize = y0 * y1 * y2 * y4,
63 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
64 int y0 = params.broadcast_shape[0];
65 int y1 = params.broadcast_shape[1];
66 int y2 = params.broadcast_shape[2];
67 int y3 = params.broadcast_shape[3];
68 int y4 = params.broadcast_shape[4];
69 if (y4 > 1)
70 {
71 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
72 // dimension.
73 for (int i0 = 0; i0 < y0; ++i0)
74 {
75 const T *input2_data_ptr = nullptr;
76 for (int i1 = 0; i1 < y1; ++i1)
77 {
78 input2_data_ptr = input2_data_reset;
79 for (int i2 = 0; i2 < y2; ++i2)
80 {
81 for (int i3 = 0; i3 < y3; ++i3)
82 {
83 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
84 input2_data_ptr += y4;
85 output_data_ptr += y4;
86 }
87 // We have broadcast y4 of input1 data y3 times, and now move on.
88 input1_data_ptr += y4;
89 }
90 }
91 // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
92 input2_data_reset = input2_data_ptr;
93 }
94 }
95 else
96 {
97 // Special case of y4 == 1, in which the innermost loop is a single element
98 // and can be combined with the next (y3) as an inner broadcast.
99 //
100 // Note that this handles the case of pure scalar broadcast when
101 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
102 // broadcast with batch (as y2 > 1).
103 //
104 // NOTE The process is the same as the above general case except simplified
105 // for y4 == 1 and the loop over y3 is contained within the
106 // AddScalarBroadcast function.
107 for (int i0 = 0; i0 < y0; ++i0)
108 {
109 const T *input2_data_ptr = nullptr;
110 for (int i1 = 0; i1 < y1; ++i1)
111 {
112 input2_data_ptr = input2_data_reset;
113 for (int i2 = 0; i2 < y2; ++i2)
114 {
115 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
116 input2_data_ptr += y3;
117 output_data_ptr += y3;
118 input1_data_ptr += 1;
119 }
120 }
121 input2_data_reset = input2_data_ptr;
122 }
123 }
124}
125
126// New version: For Mul, Add and Sub(quant8)
127template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
128inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &unswitched_params,
129 const Shape & /* unswitched_input1_shape */,
130 const T *unswitched_input1_data,
131 const Shape & /* unswitched_input2_shape */,
132 const T *unswitched_input2_data,
133 const Shape & /* output_shape */, T *output_data,
134 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
135{
136 BinaryArithmeticOpParam switched_params = unswitched_params;
137 switched_params.input1_offset = unswitched_params.input2_offset;
138 switched_params.input1_multiplier = unswitched_params.input2_multiplier;
139 switched_params.input1_shift = unswitched_params.input2_shift;
140 switched_params.input2_offset = unswitched_params.input1_offset;
141 switched_params.input2_multiplier = unswitched_params.input1_multiplier;
142 switched_params.input2_shift = unswitched_params.input1_shift;
143
144 const bool use_unswitched =
146
147 const BinaryArithmeticOpParam &params = use_unswitched ? unswitched_params : switched_params;
148 const T *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
149 const T *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
150
151 // Fivefold nested loops. The second input resets its position for each
152 // iteration of the second loop. The first input resets its position at the
153 // beginning of the fourth loop. The innermost loop is an elementwise add of
154 // sections of the arrays.
155 T *output_data_ptr = output_data;
156 const T *input1_data_ptr = input1_data;
157 const T *input2_data_reset = input2_data;
158 // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
159 // between input shapes. y3 for input 1 is always broadcast, and so the
160 // dimension there is 1, whereas optionally y1 might be broadcast for
161 // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
162 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
163 int y0 = params.broadcast_shape[0];
164 int y1 = params.broadcast_shape[1];
165 int y2 = params.broadcast_shape[2];
166 int y3 = params.broadcast_shape[3];
167 int y4 = params.broadcast_shape[4];
168 if (y4 > 1)
169 {
170 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
171 // dimension.
172 for (int i0 = 0; i0 < y0; ++i0)
173 {
174 const T *input2_data_ptr = nullptr;
175 for (int i1 = 0; i1 < y1; ++i1)
176 {
177 input2_data_ptr = input2_data_reset;
178 for (int i2 = 0; i2 < y2; ++i2)
179 {
180 for (int i3 = 0; i3 < y3; ++i3)
181 {
182 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
183 input2_data_ptr += y4;
184 output_data_ptr += y4;
185 }
186 // We have broadcast y4 of input1 data y3 times, and now move on.
187 input1_data_ptr += y4;
188 }
189 }
190 // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
191 input2_data_reset = input2_data_ptr;
192 }
193 }
194 else
195 {
196 // Special case of y4 == 1, in which the innermost loop is a single
197 // element and can be combined with the next (y3) as an inner broadcast.
198 //
199 // Note that this handles the case of pure scalar broadcast when
200 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
201 // broadcast with batch (as y2 > 1).
202 //
203 // NOTE The process is the same as the above general case except
204 // simplified for y4 == 1 and the loop over y3 is contained within the
205 // AddScalarBroadcast function.
206 for (int i0 = 0; i0 < y0; ++i0)
207 {
208 const T *input2_data_ptr = nullptr;
209 for (int i1 = 0; i1 < y1; ++i1)
210 {
211 input2_data_ptr = input2_data_reset;
212 for (int i2 = 0; i2 < y2; ++i2)
213 {
214 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
215 input2_data_ptr += y3;
216 output_data_ptr += y3;
217 input1_data_ptr += 1;
218 }
219 }
220 input2_data_reset = input2_data_ptr;
221 }
222 }
223}
224
225template <typename T>
226inline typename std::enable_if_t<is_quant8<T>::value, int32_t>
227quant8_sum(const BinaryArithmeticOpParam &params, const T input1_data, const T input2_data)
228{
229 const int32_t input1_val = params.input1_offset + input1_data;
230 const int32_t input2_val = params.input2_offset + input2_data;
231 const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
232 const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
233 const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
234 shifted_input1_val, params.input1_multiplier, params.input1_shift);
235 const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
236 shifted_input2_val, params.input2_multiplier, params.input2_shift);
237 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
238 const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
239 raw_sum, params.output_multiplier, params.output_shift) +
240 params.output_offset;
241 const int32_t clamped_output = std::min(params.quantized_activation_max,
242 std::max(params.quantized_activation_min, raw_output));
243 return clamped_output;
244}
245
// Elementwise quantized (uint8) addition over `size` elements: each input is
// offset by its zero point, pre-shifted, rescaled by its multiplier/shift,
// summed in int32, rescaled to the output scale, offset and clamped to the
// activation range. The NEON path processes 8 lanes per iteration; the scalar
// loop below handles the tail (and everything when NEON is unavailable).
inline void AddElementwise(int size, const BinaryArithmeticOpParam &params,
                           const uint8_t *input1_data, const uint8_t *input2_data,
                           uint8_t *output_data)
{
  int i = 0;

#ifdef USE_NEON
  const uint8x8_t output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
  const uint8x8_t output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
  for (; i <= size - 8; i += 8)
  {
    // Load 8 lanes of each input and widen u8 -> s16 so that adding the
    // (possibly negative) zero-point offsets cannot wrap around.
    const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
    const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
    const int16x8_t input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
    const int16x8_t input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
    const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
    const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
    // Widen to s32 (two 4-lane halves per input) for the fixed-point math.
    const int16x4_t input1_val_high = vget_high_s16(input1_val);
    const int16x4_t input1_val_low = vget_low_s16(input1_val);
    const int16x4_t input2_val_high = vget_high_s16(input2_val);
    const int16x4_t input2_val_low = vget_low_s16(input2_val);
    int32x4_t x11 = vmovl_s16(input1_val_low);
    int32x4_t x12 = vmovl_s16(input1_val_high);
    int32x4_t x21 = vmovl_s16(input2_val_low);
    int32x4_t x22 = vmovl_s16(input2_val_high);
    // Shared left pre-shift, then each input's multiplier (saturating
    // rounding doubling high-half multiply) and per-input shift — the NEON
    // analogue of MultiplyByQuantizedMultiplierSmallerThanOneExp below.
    const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
    x11 = vshlq_s32(x11, left_shift_dup);
    x12 = vshlq_s32(x12, left_shift_dup);
    x21 = vshlq_s32(x21, left_shift_dup);
    x22 = vshlq_s32(x22, left_shift_dup);
    x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
    x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
    x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
    x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
    const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
    const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
    x11 = vshlq_s32(x11, input1_shift_dup);
    x12 = vshlq_s32(x12, input1_shift_dup);
    x21 = vshlq_s32(x21, input2_shift_dup);
    x22 = vshlq_s32(x22, input2_shift_dup);
    // Sum and rescale to the output scale.
    int32x4_t s1 = vaddq_s32(x11, x21);
    int32x4_t s2 = vaddq_s32(x12, x22);
    s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
    s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    s1 = RoundingDivideByPOT(s1, -params.output_shift);
    s2 = RoundingDivideByPOT(s2, -params.output_shift);
    // Narrow s32 -> s16, add the output offset, saturate to u8 and clamp to
    // the activation range before storing.
    const int16x4_t s1_narrowed = vmovn_s32(s1);
    const int16x4_t s2_narrowed = vmovn_s32(s2);
    const int16x8_t s =
      vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.output_offset));
    const uint8x8_t clamped =
      vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
    vst1_u8(output_data + i, clamped);
  }
#endif // NEON
  // Scalar reference path for remaining elements (same math as the NEON path).
  for (; i < size; ++i)
  {
    const int32_t input1_val = params.input1_offset + input1_data[i];
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
    const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
    const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
      shifted_input1_val, params.input1_multiplier, params.input1_shift);
    const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
      shifted_input2_val, params.input2_multiplier, params.input2_shift);
    const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
    const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
                                 raw_sum, params.output_multiplier, params.output_shift) +
                               params.output_offset;
    const int32_t clamped_output = std::min(params.quantized_activation_max,
                                            std::max(params.quantized_activation_min, raw_output));
    output_data[i] = static_cast<uint8_t>(clamped_output);
  }
}
321
// Elementwise quantized (int8) addition over `size` elements. Same rescaling
// scheme as the uint8 overload above, but the NEON path processes 16 lanes
// per iteration and folds each input's shift into the shared left pre-shift
// up front (one vshlq per value) instead of shifting again after the
// multiplier.
inline void AddElementwise(int size, const BinaryArithmeticOpParam &params,
                           const int8_t *input1_data, const int8_t *input2_data,
                           int8_t *output_data)
{
  int i = 0;
#ifdef USE_NEON
  const int8x16_t output_activation_min_vector = vdupq_n_s8(params.quantized_activation_min);
  const int8x16_t output_activation_max_vector = vdupq_n_s8(params.quantized_activation_max);

  // Combined pre-shift per input: shared left_shift plus the input's own
  // shift, applied once before the multiplier.
  const int input1_left_shift = params.left_shift + params.input1_shift;
  const int input2_left_shift = params.left_shift + params.input2_shift;
  const int32x4_t input1_left_dup = vdupq_n_s32(input1_left_shift);
  const int32x4_t input2_left_dup = vdupq_n_s32(input2_left_shift);

  const int16x8_t input1_offset_dup = vdupq_n_s16(params.input1_offset);
  const int16x8_t input2_offset_dup = vdupq_n_s16(params.input2_offset);

  for (; i <= size - 16; i += 16)
  {
    // Load 16 lanes of each input; widen s8 -> s16 and add zero-point
    // offsets without risk of overflow.
    const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);

    const int16x8_t input1_val_s16_high = vmovl_s8(vget_high_s8(input1_val_original));
    const int16x8_t input1_val_s16_low = vmovl_s8(vget_low_s8(input1_val_original));

    const int16x8_t input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
    const int16x8_t input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
    const int16x8_t input1_val_high = vaddq_s16(input1_val_s16_high, input1_offset_dup);
    const int16x8_t input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_dup);
    const int16x8_t input1_val_low = vaddq_s16(input1_val_s16_low, input1_offset_dup);
    const int16x8_t input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_dup);
    // Split into four s16x4 quarters per input and widen to s32 for the
    // fixed-point math (x1QQ = input1 quarters, x2QQ = input2 quarters).
    const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
    const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
    const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
    const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
    const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
    const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
    const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
    const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
    int32x4_t x111 = vmovl_s16(input1_val_low_low);
    int32x4_t x112 = vmovl_s16(input1_val_low_high);
    int32x4_t x121 = vmovl_s16(input1_val_high_low);
    int32x4_t x122 = vmovl_s16(input1_val_high_high);
    int32x4_t x211 = vmovl_s16(input2_val_low_low);
    int32x4_t x212 = vmovl_s16(input2_val_low_high);
    int32x4_t x221 = vmovl_s16(input2_val_high_low);
    int32x4_t x222 = vmovl_s16(input2_val_high_high);

    // Combined pre-shift, then the per-input multiplier (saturating rounding
    // doubling high-half multiply).
    x111 = vshlq_s32(x111, input1_left_dup);
    x112 = vshlq_s32(x112, input1_left_dup);
    x121 = vshlq_s32(x121, input1_left_dup);
    x122 = vshlq_s32(x122, input1_left_dup);
    x211 = vshlq_s32(x211, input2_left_dup);
    x212 = vshlq_s32(x212, input2_left_dup);
    x221 = vshlq_s32(x221, input2_left_dup);
    x222 = vshlq_s32(x222, input2_left_dup);
    x111 = vqrdmulhq_n_s32(x111, params.input1_multiplier);
    x112 = vqrdmulhq_n_s32(x112, params.input1_multiplier);
    x121 = vqrdmulhq_n_s32(x121, params.input1_multiplier);
    x122 = vqrdmulhq_n_s32(x122, params.input1_multiplier);
    x211 = vqrdmulhq_n_s32(x211, params.input2_multiplier);
    x212 = vqrdmulhq_n_s32(x212, params.input2_multiplier);
    x221 = vqrdmulhq_n_s32(x221, params.input2_multiplier);
    x222 = vqrdmulhq_n_s32(x222, params.input2_multiplier);
    // Sum corresponding quarters and rescale to the output scale.
    int32x4_t s11 = vaddq_s32(x111, x211);
    int32x4_t s12 = vaddq_s32(x112, x212);
    int32x4_t s21 = vaddq_s32(x121, x221);
    int32x4_t s22 = vaddq_s32(x122, x222);
    s11 = vqrdmulhq_n_s32(s11, params.output_multiplier);
    s12 = vqrdmulhq_n_s32(s12, params.output_multiplier);
    s21 = vqrdmulhq_n_s32(s21, params.output_multiplier);
    s22 = vqrdmulhq_n_s32(s22, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    s11 = RoundingDivideByPOT(s11, -params.output_shift);
    s12 = RoundingDivideByPOT(s12, -params.output_shift);
    s21 = RoundingDivideByPOT(s21, -params.output_shift);
    s22 = RoundingDivideByPOT(s22, -params.output_shift);
    // Narrow s32 -> s16, add the output offset, saturate to s8, clamp and
    // store all 16 lanes.
    const int16x4_t s11_narrowed = vmovn_s32(s11);
    const int16x4_t s12_narrowed = vmovn_s32(s12);
    const int16x4_t s21_narrowed = vmovn_s32(s21);
    const int16x4_t s22_narrowed = vmovn_s32(s22);
    const int16x8_t s1 =
      vaddq_s16(vcombine_s16(s11_narrowed, s12_narrowed), vdupq_n_s16(params.output_offset));
    const int16x8_t s2 =
      vaddq_s16(vcombine_s16(s21_narrowed, s22_narrowed), vdupq_n_s16(params.output_offset));
    const int8x16_t s = vcombine_s8(vqmovn_s16(s1), vqmovn_s16(s2));

    const int8x16_t clamped =
      vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector, s));
    vst1q_s8(output_data + i, clamped);
  }
#endif // NEON

  // Scalar reference path for remaining elements (same math, with the
  // per-input shift applied after the multiplier instead of folded in).
  for (; i < size; ++i)
  {
    const int32_t input1_val = params.input1_offset + input1_data[i];
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
    const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
    const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
      shifted_input1_val, params.input1_multiplier, params.input1_shift);
    const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
      shifted_input2_val, params.input2_multiplier, params.input2_shift);
    const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
    const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
                                 raw_sum, params.output_multiplier, params.output_shift) +
                               params.output_offset;
    const int32_t clamped_output = std::min(params.quantized_activation_max,
                                            std::max(params.quantized_activation_min, raw_output));
    output_data[i] = static_cast<int8_t>(clamped_output);
  }
}
434
// Addition functor for the float binary-op kernels below. The struct
// declaration line was lost in extraction; the name is taken from its use in
// getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>.
struct BinaryOpFuncAddFloat
{
#ifdef USE_NEON
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
  {
    return vaddq_f32(a, b);
  }
#endif // USE_NEON
  static inline float calculate(const float a, const float b) { return a + b; }
};
445
// Subtraction functor (a - b) for the float binary-op kernels below. The
// struct declaration line was lost in extraction; name reconstructed to match
// the BinaryOpFunc*Float family.
struct BinaryOpFuncSubFloat
{
#ifdef USE_NEON
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
  {
    return vsubq_f32(a, b);
  }
#endif // USE_NEON
  static inline float calculate(const float a, const float b) { return a - b; }
};
456
// Multiplication functor for the float binary-op kernels below. The struct
// declaration line was lost in extraction; name reconstructed to match the
// BinaryOpFunc*Float family.
struct BinaryOpFuncMulFloat
{
#ifdef USE_NEON
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
  {
    return vmulq_f32(a, b);
  }
#endif // USE_NEON
  static inline float calculate(const float a, const float b) { return a * b; }
};
467
// Division functor (a / b) for the float binary-op kernels below. The struct
// declaration line was lost in extraction; name reconstructed to match the
// BinaryOpFunc*Float family. vdivq_f32 only exists on AArch64, hence the
// extra guard: on 32-bit NEON the scalar overload is the only one available.
struct BinaryOpFuncDivFloat
{
#ifdef USE_NEON
#ifdef __aarch64__
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
  {
    return vdivq_f32(a, b);
  }
#endif // __aarch64__
#endif // USE_NEON
  static inline float calculate(const float a, const float b) { return a / b; }
};
480
// Adapter functor that forwards to BASEOPERATOR::calculate with the operands
// exchanged, so a non-commutative op can be run with reversed inputs.
template <class BASEOPERATOR> struct BinaryOpFuncSwapArgs
{
  template <typename T> static inline T calculate(const T &lhs, const T &rhs)
  {
    return BASEOPERATOR::calculate(rhs, lhs);
  }
};
488
// No-op activation: both the floor and the ceiling clamp pass the value
// through untouched. Used when the op has no fused activation. The struct
// declaration line was lost in extraction; name taken from its use in
// getBinaryOpWithActivationImplFloat.
struct BinaryOpActivationFloatNone
{
#ifdef USE_NEON
  static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
  {
    (void)ceilingParam; // suppress unused argument warning
    return value;
  }
  static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
  {
    (void)floorParam;
    return value;
  }
#endif // USE_NEON
  static inline float applyCeiling(const float value, const float ceilingParam)
  {
    (void)ceilingParam;
    return value;
  }
  static inline float applyFloor(const float value, const float floorParam)
  {
    (void)floorParam;
    return value;
  }
};
514
// Lower-bound-only activation: applyFloor clamps from below via max(), while
// applyCeiling passes through. Selected when only float_activation_min is
// restrictive (e.g. ReLU). The struct declaration line was lost in
// extraction; name taken from its use in getBinaryOpWithActivationImplFloat.
struct BinaryOpActivationFloatMax
{
#ifdef USE_NEON
  static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
  {
    (void)ceilingParam; // suppress unused argument warning
    return value;
  }
  static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
  {
    return vmaxq_f32(value, floorParam);
  }
#endif // USE_NEON
  static inline float applyCeiling(const float value, const float ceilingParam)
  {
    (void)ceilingParam;
    return value;
  }
  static inline float applyFloor(const float value, const float floorParam)
  {
    return std::max(value, floorParam);
  }
};
538
// Full clamp activation: applyFloor clamps from below (max) and applyCeiling
// from above (min), implementing a min/max activation range such as ReLU6.
// The struct declaration line was lost in extraction; name taken from its use
// in getBinaryOpWithActivationImplFloat.
struct BinaryOpActivationFloatMinMax
{
#ifdef USE_NEON
  static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
  {
    return vminq_f32(value, ceilingParam);
  }
  static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
  {
    return vmaxq_f32(value, floorParam);
  }
#endif // USE_NEON
  static inline float applyCeiling(const float value, const float ceilingParam)
  {
    return std::min(value, ceilingParam);
  }
  static inline float applyFloor(const float value, const float floorParam)
  {
    return std::max(value, floorParam);
  }
};
560
// Applies OPERATOR elementwise to two equally-sized float buffers, then the
// ACTIVATION clamp (floor, then ceiling) to each result. The NEON path runs
// 16 floats per iteration, then 4; the scalar loop takes whatever remains.
template <class OPERATOR, class ACTIVATION>
inline void BinaryOpElementwise(int size, const BinaryArithmeticOpParam &params,
                                const float *input1_data, const float *input2_data,
                                float *output_data)
{
  int i = 0;

#ifdef USE_NEON
  const auto activation_min = vdupq_n_f32(params.float_activation_min);
  const auto activation_max = vdupq_n_f32(params.float_activation_max);
  // Main loop: four q-registers (16 floats) per input per iteration.
  for (; i <= size - 16; i += 16)
  {
    auto a10 = vld1q_f32(input1_data + i);
    auto a11 = vld1q_f32(input1_data + i + 4);
    auto a12 = vld1q_f32(input1_data + i + 8);
    auto a13 = vld1q_f32(input1_data + i + 12);
    auto a20 = vld1q_f32(input2_data + i);
    auto a21 = vld1q_f32(input2_data + i + 4);
    auto a22 = vld1q_f32(input2_data + i + 8);
    auto a23 = vld1q_f32(input2_data + i + 12);
    auto x0 = OPERATOR::calculate(a10, a20);
    auto x1 = OPERATOR::calculate(a11, a21);
    auto x2 = OPERATOR::calculate(a12, a22);
    auto x3 = OPERATOR::calculate(a13, a23);
    // Clamp: floor first, then ceiling (no-ops for unbounded activations).
    x0 = ACTIVATION::applyFloor(x0, activation_min);
    x1 = ACTIVATION::applyFloor(x1, activation_min);
    x2 = ACTIVATION::applyFloor(x2, activation_min);
    x3 = ACTIVATION::applyFloor(x3, activation_min);
    x0 = ACTIVATION::applyCeiling(x0, activation_max);
    x1 = ACTIVATION::applyCeiling(x1, activation_max);
    x2 = ACTIVATION::applyCeiling(x2, activation_max);
    x3 = ACTIVATION::applyCeiling(x3, activation_max);
    vst1q_f32(output_data + i, x0);
    vst1q_f32(output_data + i + 4, x1);
    vst1q_f32(output_data + i + 8, x2);
    vst1q_f32(output_data + i + 12, x3);
  }
  // Secondary loop: one q-register (4 floats) at a time.
  for (; i <= size - 4; i += 4)
  {
    auto a1 = vld1q_f32(input1_data + i);
    auto a2 = vld1q_f32(input2_data + i);
    auto x = OPERATOR::calculate(a1, a2); // vaddq
    auto x_clamped =
      ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
    vst1q_f32(output_data + i, x_clamped);
  }
#endif // USE_NEON
  // Scalar tail (and the whole range when NEON is unavailable).
  for (; i < size; i++)
  {
    auto x = OPERATOR::calculate(input1_data[i], input2_data[i]);
    output_data[i] = ACTIVATION::applyCeiling(
      ACTIVATION::applyFloor(x, params.float_activation_min), params.float_activation_max);
  }
}
615
616// Broadcast binary op template that can often be used for inner loop
617// This function will handle scalar_value (LHS) and vector_values (RHS).
618// Since it's a float function, input params does not matter here.
619template <class OPERATOR, class ACTIVATION>
621 const float broadcast_value, const float *input2_data,
622 float *output_data)
623{
624 int i = 0;
625
626#ifdef USE_NEON
627 const auto activation_min = vdupq_n_f32(params.float_activation_min);
628 const auto activation_max = vdupq_n_f32(params.float_activation_max);
629 const auto broadcast_value_dup = vdupq_n_f32(broadcast_value);
630 for (; i <= size - 16; i += 16)
631 {
632 auto a20 = vld1q_f32(input2_data + i);
633 auto a21 = vld1q_f32(input2_data + i + 4);
634 auto a22 = vld1q_f32(input2_data + i + 8);
635 auto a23 = vld1q_f32(input2_data + i + 12);
636 auto x0 = OPERATOR::calculate(broadcast_value_dup, a20);
637 auto x1 = OPERATOR::calculate(broadcast_value_dup, a21);
638 auto x2 = OPERATOR::calculate(broadcast_value_dup, a22);
639 auto x3 = OPERATOR::calculate(broadcast_value_dup, a23);
640 x0 = ACTIVATION::applyFloor(x0, activation_min);
641 x1 = ACTIVATION::applyFloor(x1, activation_min);
642 x2 = ACTIVATION::applyFloor(x2, activation_min);
643 x3 = ACTIVATION::applyFloor(x3, activation_min);
644 x0 = ACTIVATION::applyCeiling(x0, activation_max);
645 x1 = ACTIVATION::applyCeiling(x1, activation_max);
646 x2 = ACTIVATION::applyCeiling(x2, activation_max);
647 x3 = ACTIVATION::applyCeiling(x3, activation_max);
648 vst1q_f32(output_data + i, x0);
649 vst1q_f32(output_data + i + 4, x1);
650 vst1q_f32(output_data + i + 8, x2);
651 vst1q_f32(output_data + i + 12, x3);
652 }
653 for (; i <= size - 4; i += 4)
654 {
655 auto a2 = vld1q_f32(input2_data + i);
656 auto x = OPERATOR::calculate(broadcast_value_dup, a2);
657 auto x_clamped =
658 ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
659 vst1q_f32(output_data + i, x_clamped);
660 }
661#endif // USE_NEON
662 for (; i < size; i++)
663 {
664 auto x = OPERATOR::calculate(broadcast_value, input2_data[i]);
665 output_data[i] = ACTIVATION::applyCeiling(
666 ACTIVATION::applyFloor(x, params.float_activation_min), params.float_activation_max);
667 }
668}
669
671 std::pair<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *, float *),
672 void (*)(int, const BinaryArithmeticOpParam &, const float, const float *, float *)>;
673
674template <class FUNC>
677{
678 if (params.float_activation_max == std::numeric_limits<float>::max())
679 if (params.float_activation_min == std::numeric_limits<float>::lowest())
680 return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatNone>,
681 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatNone>);
682 else
683 return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatMax>,
684 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMax>);
685 else
686 return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatMinMax>,
687 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMinMax>);
688}
689
690template <typename T>
691inline typename std::enable_if_t<is_quant8<T>::value>
692Add(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data,
693 const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
694{
695 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
696 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
697}
698
699inline void Add(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
700 const float *input1_data, const Shape &input2_shape, const float *input2_data,
701 const Shape &output_shape, float *output_data)
702{
703 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
704 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
705 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
706}
707
708// Scalar-broadcast add that can be used for inner loop of more general
709// broadcast add, so that, for example, scalar-broadcast with batch will still
710// be fast.
711inline void AddScalarBroadcast(int size, const BinaryArithmeticOpParam &params,
712 uint8_t broadcast_value, const uint8_t *input2_data,
713 uint8_t *output_data)
714{
715 int i = 0;
716 int32_t clamped_output;
717 for (; i < size; ++i)
718 {
719 clamped_output = quant8_sum(params, broadcast_value, input2_data[i]);
720 output_data[i] = static_cast<uint8_t>(clamped_output);
721 }
722}
723
724// Scalar-broadcast add that can be used for inner loop of more general
725// broadcast add, so that, for example, scalar-broadcast with batch will still
726// be fast.
// Scalar-broadcast quantized add (int8): adds the single scalar `input1_data`
// to each of `size` elements of `input2_data`, writing requantized results to
// `output_data`. The scalar operand's requantization (offset, left shift,
// fixed-point multiply, per-input shift) is hoisted out of the loop.
inline void AddScalarBroadcast(int size, const BinaryArithmeticOpParam &params, int8_t input1_data,
                               const int8_t *input2_data, int8_t *output_data)
{
  using gemmlowp::RoundingDivideByPOT;
  int i = 0;
#ifdef USE_NEON
  const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
  const int8x8_t output_activation_min_vector = vdup_n_s8(params.quantized_activation_min);
  const int8x8_t output_activation_max_vector = vdup_n_s8(params.quantized_activation_max);

  // Process broadcast scalar.
  // Widen s8 -> s16, add input1_offset, widen to two s32x4 halves, then
  // rescale once: shift left, saturating-rounding-doubling multiply by
  // input1_multiplier, and shift by input1_shift. x11/x12 are reused by
  // every loop iteration below.
  const int8x8_t input1_val_original = vdup_n_s8(input1_data);
  const int16x8_t input1_val_s16 = vmovl_s8(input1_val_original);
  const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
  const int16x4_t input1_val_high = vget_high_s16(input1_val);
  const int16x4_t input1_val_low = vget_low_s16(input1_val);
  int32x4_t x11 = vmovl_s16(input1_val_low);
  int32x4_t x12 = vmovl_s16(input1_val_high);
  x11 = vshlq_s32(x11, left_shift_dup);
  x12 = vshlq_s32(x12, left_shift_dup);
  x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
  x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
  const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
  x11 = vshlq_s32(x11, input1_shift_dup);
  x12 = vshlq_s32(x12, input1_shift_dup);

  // Vector loop: 8 elements per iteration; the per-element operand is
  // requantized the same way as the scalar above, then summed, requantized
  // to the output scale, offset, and clamped to the activation range.
  for (; i <= size - 8; i += 8)
  {
    const int8x8_t input2_val_original = vld1_s8(input2_data + i);
    const int16x8_t input2_val_s16 = vmovl_s8(input2_val_original);
    const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
    const int16x4_t input2_val_high = vget_high_s16(input2_val);
    const int16x4_t input2_val_low = vget_low_s16(input2_val);
    int32x4_t x21 = vmovl_s16(input2_val_low);
    int32x4_t x22 = vmovl_s16(input2_val_high);
    x21 = vshlq_s32(x21, left_shift_dup);
    x22 = vshlq_s32(x22, left_shift_dup);
    x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
    x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
    const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
    x21 = vshlq_s32(x21, input2_shift_dup);
    x22 = vshlq_s32(x22, input2_shift_dup);
    int32x4_t s1 = vaddq_s32(x11, x21);
    int32x4_t s2 = vaddq_s32(x12, x22);
    s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
    s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
    // output_shift is a (typically negative) exponent; negate for the
    // rounding right-shift.
    s1 = RoundingDivideByPOT(s1, -params.output_shift);
    s2 = RoundingDivideByPOT(s2, -params.output_shift);
    const int16x4_t s1_narrowed = vmovn_s32(s1);
    const int16x4_t s2_narrowed = vmovn_s32(s2);
    const int16x8_t s =
      vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.output_offset));
    const int8x8_t clamped =
      vmax_s8(output_activation_min_vector, vmin_s8(output_activation_max_vector, vqmovn_s16(s)));
    vst1_s8(output_data + i, clamped);
  }
#endif // NEON

  // Scalar tail (and the whole range when NEON is unavailable).
  if (i < size)
  {
    // Process broadcast scalar.
    const int32_t input1_val = params.input1_offset + input1_data;
    const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
    const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
      shifted_input1_val, params.input1_multiplier, params.input1_shift);

    for (; i < size; ++i)
    {
      const int32_t input2_val = params.input2_offset + input2_data[i];
      const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
      const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
        shifted_input2_val, params.input2_multiplier, params.input2_shift);
      const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
      const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
                                   raw_sum, params.output_multiplier, params.output_shift) +
                                 params.output_offset;
      const int32_t clamped_output = std::min(
        params.quantized_activation_max, std::max(params.quantized_activation_min, raw_output));
      output_data[i] = static_cast<int8_t>(clamped_output);
    }
  }
}
809
810template <typename T>
811inline typename std::enable_if_t<is_quant8<T>::value>
812BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
813 const T *input1_data, const Shape &input2_shape, const T *input2_data,
814 const Shape &output_shape, T *output_data)
815{
817 {
818 const std::function<T(const BinaryArithmeticOpParam &, const T &, const T &)> fn =
819 [](const BinaryArithmeticOpParam &params, const T &a, const T &b) {
820 return static_cast<T>(quant8_sum(params, a, b));
821 };
822 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
823 input2_data, output_shape, output_data, fn);
824 return;
825 }
826
828 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
829 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const T *, const T *, T *)>(
831 static_cast<void (*)(int, const BinaryArithmeticOpParam &, T, const T *, T *)>(
833}
834
835inline void BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
836 const float *input1_data, const Shape &input2_shape,
837 const float *input2_data, const Shape &output_shape,
838 float *output_data)
839{
841 {
842 const std::function<float(const float &, const float &)> fn =
843 [](const float &a, const float &b) -> float { return a + b; };
844 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
845 input2_data, output_shape, output_data, fn);
846 }
847 else
848 {
849 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
850
853 input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
854 implFuncs.first, implFuncs.second);
855 }
856}
857
858inline void Sub(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
859 const float *input1_data, const Shape &input2_shape, const float *input2_data,
860 const Shape &output_shape, float *output_data)
861{
862 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
863 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
864 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
865}
866
867inline void BroadcastSubDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
868 const float *input1_data, const Shape &input2_shape,
869 const float *input2_data, const Shape &output_shape,
870 float *output_data)
871{
873 {
874 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
875 BinaryBroadcastFiveFold(params, false, input1_shape, input1_data, input2_shape, input2_data,
876 output_shape, output_data, implFuncs.first, implFuncs.second);
877 }
879 {
880 auto implFuncs =
881 getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncSubFloat>>(params);
882 BinaryBroadcastFiveFold(params, true, input1_shape, input1_data, input2_shape, input2_data,
883 output_shape, output_data, implFuncs.first, implFuncs.second);
884 }
885 else
886 {
887 const std::function<float(const float &, const float &)> fn =
888 [](const float &a, const float &b) -> float { return a - b; };
889 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
890 input2_data, output_shape, output_data, fn);
891 }
892}
893
894template <typename T>
895inline typename std::enable_if_t<is_quant8<T>::value, int32_t>
896quant8_mul(const BinaryArithmeticOpParam &params, const T input1_data, const T input2_data)
897{
898 const int32_t input1_val = params.input1_offset + input1_data;
899 const int32_t input2_val = params.input2_offset + input2_data;
900 const int32_t unclamped_result =
901 params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
902 params.output_multiplier,
903 params.output_shift);
904 const int32_t clamped_output = std::min(
905 params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
906
907 return clamped_output;
908}
909
// Element-wise quantized multiply (uint8): for each i,
// out[i] = clamp(requantize((in1[i] + off1) * (in2[i] + off2)) + out_off).
inline void MulElementwise(int size, const BinaryArithmeticOpParam &params,
                           const uint8_t *input1_data, const uint8_t *input2_data,
                           uint8_t *output_data)
{
  int i = 0;

#ifdef USE_NEON
  const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
  const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const auto output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
  const auto output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
  // Split the signed output_shift exponent into a non-negative left shift
  // (applied before the fixed-point multiply) and right shift (applied after).
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 8; i += 8)
  {
    // We load / store 8 at a time, multiplying as two sets of 4 int32s.
    const auto input1_val_original = vld1_u8(input1_data + i);
    const auto input2_val_original = vld1_u8(input2_data + i);
    // Widen u8 -> s16 and apply the input zero-point offsets.
    const auto input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
    const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
    const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
    const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);

    const auto input1_val_low = vget_low_s16(input1_val);
    const auto input1_val_high = vget_high_s16(input1_val);
    const auto input2_val_low = vget_low_s16(input2_val);
    const auto input2_val_high = vget_high_s16(input2_val);

    // Widening s16 x s16 -> s32 products.
    auto p1 = vmull_s16(input2_val_low, input1_val_low);
    auto p2 = vmull_s16(input2_val_high, input1_val_high);

    // Requantize: left shift, saturating-rounding-doubling multiply by the
    // output multiplier, rounding right shift.
    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);

    // Narrow back, add the output offset, saturate to u8, clamp to the
    // fused-activation range, and store.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p = vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
    const auto clamped =
      vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
    vst1_u8(output_data + i, clamped);
  }
#endif // NEON

  // Scalar tail (and the whole range when NEON is unavailable).
  for (; i < size; ++i)
  {
    const int32_t input1_val = params.input1_offset + input1_data[i];
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t unclamped_result =
      params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                                           params.output_multiplier,
                                                           params.output_shift);
    const int32_t clamped_output = std::min(
      params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<uint8_t>(clamped_output);
  }
}
973
// Element-wise quantized multiply (int8): same math as the uint8 overload
// above, processing 16 lanes per NEON iteration as four s32x4 groups.
inline void MulElementwise(int size, const BinaryArithmeticOpParam &params,
                           const int8_t *input1_data, const int8_t *input2_data,
                           int8_t *output_data)
{
  int i = 0;
#ifdef USE_NEON
  const int16x8_t input1_offset_vector = vdupq_n_s16(params.input1_offset);
  const int16x8_t input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const int16x8_t output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector = vdupq_n_s8(params.quantized_activation_min);
  const auto output_activation_max_vector = vdupq_n_s8(params.quantized_activation_max);
  // Split the signed output_shift exponent into a non-negative left shift
  // (before the fixed-point multiply) and right shift (after).
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 16; i += 16)
  {
    // We load / store 16 at a time, multiplying as four sets of 4 int32s.
    const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);

    // Widen s8 -> s16 in two halves, then apply input zero-point offsets.
    const int16x8_t input1_val_s16_high = vmovl_s8(vget_high_s8(input1_val_original));
    const int16x8_t input1_val_s16_low = vmovl_s8(vget_low_s8(input1_val_original));

    const int16x8_t input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
    const int16x8_t input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
    const int16x8_t input1_val_high = vaddq_s16(input1_val_s16_high, input1_offset_vector);
    const int16x8_t input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_vector);
    const int16x8_t input1_val_low = vaddq_s16(input1_val_s16_low, input1_offset_vector);
    const int16x8_t input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_vector);
    // Split each s16x8 half into s16x4 quarters for widening multiplies.
    const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
    const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
    const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
    const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
    const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
    const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
    const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
    const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);

    // Widening s16 x s16 -> s32 products, one per quarter.
    auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high);
    auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low);
    auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high);
    auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low);

    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p3 = vshlq_s32(p3, left_shift_vec);
    p4 = vshlq_s32(p4, left_shift_vec);

    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
    p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);
    p3 = RoundingDivideByPOT(p3, right_shift);
    p4 = RoundingDivideByPOT(p4, right_shift);

    // Narrow s32 -> s16, add the output offset, then saturate-narrow to s8.
    // Recombination order (p2,p1)/(p4,p3) then (part2,part1) restores the
    // original lane order of the 16 loaded elements.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p3_narrowed = vqmovn_s32(p3);
    const auto p4_narrowed = vqmovn_s32(p4);

    const int16x8_t p_part1 =
      vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
    const int16x8_t p_part2 =
      vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
    const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));

    const auto clamped =
      vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector, p));
    vst1q_s8(output_data + i, clamped);
  }
#endif // NEON

  // Scalar tail (and the whole range when NEON is unavailable).
  for (; i < size; ++i)
  {
    const int32_t input1_val = params.input1_offset + input1_data[i];
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t unclamped_result =
      params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                                           params.output_multiplier,
                                                           params.output_shift);
    const int32_t clamped_output = std::min(
      params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<int8_t>(clamped_output);
  }
}
1062
1063template <typename T>
1064inline typename std::enable_if_t<is_quant8<T>::value>
1065Mul(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data,
1066 const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
1067{
1068 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
1069 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
1070}
1071
1072inline void Mul(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
1073 const float *input1_data, const Shape &input2_shape, const float *input2_data,
1074 const Shape &output_shape, float *output_data)
1075{
1076 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
1077 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
1078 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
1079}
1080
1081inline void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam &params,
1082 const uint8_t broadcast_value, const uint8_t *input2_data,
1083 uint8_t *output_data)
1084{
1085 int i = 0;
1086 int32_t clamped_output;
1087 for (; i < size; ++i)
1088 {
1089 clamped_output = quant8_mul(params, broadcast_value, input2_data[i]);
1090 output_data[i] = static_cast<uint8_t>(clamped_output);
1091 }
1092}
1093
// Broadcast mul that can often be used for inner loop of broadcast Mul.
// int8 variant: multiplies every element of `input2_data` by the single
// `broadcast_value`; the scalar operand's offset is applied once up front.
inline void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam &params,
                               const int8_t broadcast_value, const int8_t *input2_data,
                               int8_t *output_data)
{
  // Zero-point-adjusted scalar, reused by both the NEON loop and the tail.
  const int16_t input1_val = params.input1_offset + broadcast_value;

  int i = 0;
#ifdef USE_NEON
  const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const auto output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector = vdupq_n_s8(params.quantized_activation_min);
  const auto output_activation_max_vector = vdupq_n_s8(params.quantized_activation_max);
  // Split the signed output_shift exponent into a non-negative left shift
  // (before the fixed-point multiply) and right shift (after).
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 16; i += 16)
  {
    // We load / store 16 at a time, multiplying as four sets of 4 int32s.
    const auto input2_val_original = vld1q_s8(input2_data + i);
    // Widen s8 -> s16 in halves and apply the input2 zero-point offset.
    const auto input2_val_s16_high = vmovl_s8(vget_high_s8(input2_val_original));
    const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));

    const auto input2_val_high = vaddq_s16(input2_val_s16_high, input2_offset_vector);
    const auto input2_val_low = vaddq_s16(input2_val_s16_low, input2_offset_vector);

    const auto input2_val_low_low = vget_low_s16(input2_val_low);
    const auto input2_val_low_high = vget_high_s16(input2_val_low);
    const auto input2_val_high_low = vget_low_s16(input2_val_high);
    const auto input2_val_high_high = vget_high_s16(input2_val_high);

    // Widening multiply of each s16x4 quarter by the broadcast scalar.
    auto p1 = vmull_n_s16(input2_val_high_high, input1_val);
    auto p2 = vmull_n_s16(input2_val_high_low, input1_val);
    auto p3 = vmull_n_s16(input2_val_low_high, input1_val);
    auto p4 = vmull_n_s16(input2_val_low_low, input1_val);

    // Requantize: left shift, fixed-point multiply, rounding right shift.
    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p3 = vshlq_s32(p3, left_shift_vec);
    p4 = vshlq_s32(p4, left_shift_vec);

    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
    p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);
    p3 = RoundingDivideByPOT(p3, right_shift);
    p4 = RoundingDivideByPOT(p4, right_shift);

    // Narrow, add the output offset, saturate to s8 in original lane order,
    // clamp to the activation range, and store.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p3_narrowed = vqmovn_s32(p3);
    const auto p4_narrowed = vqmovn_s32(p4);

    const int16x8_t p_part1 =
      vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
    const int16x8_t p_part2 =
      vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
    const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));

    const auto clamped =
      vmaxq_s8(output_activation_min_vector, vminq_s8(output_activation_max_vector, p));
    vst1q_s8(output_data + i, clamped);
  }
#endif // NEON

  // Scalar tail (and the whole range when NEON is unavailable).
  for (; i < size; ++i)
  {
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t unclamped_result =
      params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                                           params.output_multiplier,
                                                           params.output_shift);
    const int32_t clamped_output = std::min(
      params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<int8_t>(clamped_output);
  }
}
1174
1175template <typename T>
1176inline typename std::enable_if_t<is_quant8<T>::value>
1177BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
1178 const T *input1_data, const Shape &input2_shape, const T *input2_data,
1179 const Shape &output_shape, T *output_data)
1180{
1182 {
1183 const std::function<T(const BinaryArithmeticOpParam &, const T &, const T &)> fn =
1184 [](const BinaryArithmeticOpParam &params, const T &a, const T &b) {
1185 return static_cast<T>(quant8_mul(params, a, b));
1186 };
1187 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
1188 input2_data, output_shape, output_data, fn);
1189 return;
1190 }
1192 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
1193 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const T *, const T *, T *)>(
1195 static_cast<void (*)(int, const BinaryArithmeticOpParam &, T, const T *, T *)>(
1197}
1198
1199inline void BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
1200 const float *input1_data, const Shape &input2_shape,
1201 const float *input2_data, const Shape &output_shape,
1202 float *output_data)
1203{
1205 {
1206 // TODO: Use GetBinaryArithmeticFn
1207 const std::function<float(const float &, const float &)> fn =
1208 [](const float &a, const float &b) -> float { return a * b; };
1209 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
1210 input2_data, output_shape, output_data, fn);
1211 return;
1212 }
1213 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
1214 BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape, input2_data,
1215 output_shape, output_data, implFuncs.first, implFuncs.second);
1216}
1217
1218inline void Div(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
1219 const float *input1_data, const Shape &input2_shape, const float *input2_data,
1220 const Shape &output_shape, float *output_data)
1221{
1222#ifdef __aarch64__
1223 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
1224 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
1225 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
1226#else
1227 const std::function<float(const float &, const float &)> fn =
1228 [](const float &a, const float &b) -> float { return a / b; };
1229 reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
1230 output_shape, output_data, fn);
1231#endif // __aarch64__
1232}
1233
1234inline void BroadcastDivDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
1235 const float *input1_data, const Shape &input2_shape,
1236 const float *input2_data, const Shape &output_shape,
1237 float *output_data)
1238{
1239#ifdef __aarch64__
1241 {
1242 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
1243 BinaryBroadcastFiveFold(params, false, input1_shape, input1_data, input2_shape, input2_data,
1244 output_shape, output_data, implFuncs.first, implFuncs.second);
1245 }
1247 {
1248 auto implFuncs =
1249 getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncDivFloat>>(params);
1250 BinaryBroadcastFiveFold(params, true, input1_shape, input1_data, input2_shape, input2_data,
1251 output_shape, output_data, implFuncs.first, implFuncs.second);
1252 }
1253 else
1254#endif // __aarch64__
1255 {
1256 const std::function<float(const float &, const float &)> fn =
1257 [](const float &a, const float &b) -> float { return a / b; };
1258 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
1259 input2_data, output_shape, output_data, fn);
1260 }
1261}
1262
1263} // namespace optimized
1264} // namespace cker
1265} // namespace nnfw
1266
1267#endif // __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
const luci_interpreter::RuntimeShape output_shape
void BroadcastDivDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
std::enable_if_t< is_quant8< T >::value > Mul(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam &params, const uint8_t broadcast_value, const uint8_t *input2_data, uint8_t *output_data)
std::enable_if_t< is_quant8< T >::value, int32_t > quant8_mul(const BinaryArithmeticOpParam &params, const T input1_data, const T input2_data)
std::enable_if_t< is_quant8< T >::value, int32_t > quant8_sum(const BinaryArithmeticOpParam &params, const T input1_data, const T input2_data)
void AddElementwise(int size, const BinaryArithmeticOpParam &params, const uint8_t *input1_data, const uint8_t *input2_data, uint8_t *output_data)
void BinaryOpScalarBroadcast(int size, const BinaryArithmeticOpParam &params, const float broadcast_value, const float *input2_data, float *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void Sub(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
void MulElementwise(int size, const BinaryArithmeticOpParam &params, const uint8_t *input1_data, const uint8_t *input2_data, uint8_t *output_data)
void AddScalarBroadcast(int size, const BinaryArithmeticOpParam &params, uint8_t broadcast_value, const uint8_t *input2_data, uint8_t *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
void BroadcastSubDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
BinaryOpImplFloatFuncs getBinaryOpWithActivationImplFloat(const BinaryArithmeticOpParam &params)
void BinaryOpElementwise(int size, const BinaryArithmeticOpParam &params, const float *input1_data, const float *input2_data, float *output_data)
void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &params, bool switch_inputs, const Shape &, const T *unswitched_input1_data, const Shape &, const T *unswitched_input2_data, const Shape &, T *output_data, ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
void Div(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, float *output_data)
std::pair< void(*)(int, const BinaryArithmeticOpParam &, const float *, const float *, float *), void(*)(int, const BinaryArithmeticOpParam &, const float, const float *, float *)> BinaryOpImplFloatFuncs
std::enable_if_t< is_quant8< T >::value > Add(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data)
std::enable_if_t< is_quant8< T >::value > BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, const std::function< T(const BinaryArithmeticOpParam &params, const T &, const T &)> &fn)
void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, const std::function< T(const T &, const T &)> &fn)
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
Definition Shape.h:333
int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
Definition Utils.h:96
int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
Definition Utils.h:111
Definition topk_v2.h:30
int32_t size[5]
Definition Slice.cpp:35
BroadcastableOpCategory broadcast_category
Definition Types.h:181
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float applyFloor(const float value, const float floorParam)
static float applyCeiling(const float value, const float ceilingParam)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static float calculate(const float a, const float b)
static T calculate(const T &a, const T &b)