namespace depthwise_conv
{

// Implementation of quantized DepthwiseConv
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct QuantizedDepthwiseConvKernel
{
};

template <> struct QuantizedDepthwiseConvKernel<true, 8, 2>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    // Load the filters, add filter_offset.
    uint8x8x2_t filter_u8;
    filter_u8.val[0] = vld1_u8(filter_ptr);
    filter_u8.val[1] = vld1_u8(filter_ptr + 8);
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++)
      filter[i] =
        vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])), vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
    {
      // Load the accumulators from acc_buffer.
      int32x4x2_t acc[2];
      for (int i = 0; i < 2; i++)
      {
        acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
        acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
      }
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += input_ptr_increment;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold.
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate.
      for (int i = 0; i < 2; i++)
      {
        acc[0].val[i] =
          vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]), vget_low_s16(input_dup2.val[i]));
        acc[1].val[i] =
          vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]), vget_high_s16(input_dup2.val[i]));
      }
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 2; i++)
      {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
        vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
      }
      acc_buffer_ptr += 16;
    }
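// Each of the specializations below unrolls the accumulation for one fixed
// (input depth, depth multiplier) pair; the boolean parameter selects whether
// a non-unit stride (applied through input_ptr_increment) is supported.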
template <> struct QuantizedDepthwiseConvKernel<false, 8, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2)
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++)
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
125 uint8x8_t input_u8[2];
126 for (
int i = 0; i < 2; i++)
128 input_u8[i] = vld1_u8(input_ptr + 8 * i);
132 for (
int i = 0; i < 2; i++)
134 input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
136 for (
int i = 0; i < 2; i++)
138 input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
141 acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0]));
142 acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0]));
143 acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1]));
144 acc[3] = vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1]));
146 for (
int i = 0; i < 4; i++)
148 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
150 acc_buffer_ptr += 16;
153 for (; outp < num_output_pixels; outp++)
157 acc[0] = vld1q_s32(acc_buffer_ptr);
158 acc[1] = vld1q_s32(acc_buffer_ptr + 4);
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
166 acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input));
167 acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input));
169 vst1q_s32(acc_buffer_ptr, acc[0]);
170 vst1q_s32(acc_buffer_ptr + 4, acc[1]);
template <> struct QuantizedDepthwiseConvKernel<false, 4, 2>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2)
196 for (
int i = 0; i < 4; i++)
198 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
201 const uint8x8_t input_u8 = vld1_u8(input_ptr);
203 const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold.
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate.
      for (int i = 0; i < 2; i++)
      {
        acc[2 * i + 0] =
          vmlal_s16(acc[2 * i + 0], vget_low_s16(filter), vget_low_s16(input_dup2.val[i]));
        acc[2 * i + 1] =
          vmlal_s16(acc[2 * i + 1], vget_high_s16(filter), vget_high_s16(input_dup2.val[i]));
      }
216 for (
int i = 0; i < 4; i++)
218 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
220 acc_buffer_ptr += 16;
223 for (; outp < num_output_pixels; outp++)
227 for (
int i = 0; i < 2; i++)
229 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
232 uint8x8_t input_u8 = vdup_n_u8(0);
233 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
234 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
235 input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
236 input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
241 const int16x4x2_t input_dup2 = vzip_s16(input, input);
243 acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]);
244 acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]);
246 for (
int i = 0; i < 2; i++)
248 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
template <> struct QuantizedDepthwiseConvKernel<false, 2, 8>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
    // Load the filters, add filter_offset.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++)
    {
      const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
      const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
      filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    }
    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2)
278 for (
int i = 0; i < 8; i++)
280 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
283 uint8x8_t input_u8 = vdup_n_u8(0);
284 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
285 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
286 input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
287 input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
292 acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
293 acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
294 acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
295 acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
296 acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2);
297 acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2);
298 acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3);
299 acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3);
301 for (
int i = 0; i < 8; i++)
303 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
305 acc_buffer_ptr += 32;
308 for (; outp < num_output_pixels; outp++)
312 for (
int i = 0; i < 4; i++)
314 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
317 uint8x8_t input_u8 = vdup_n_u8(0);
318 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
319 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
325 acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
326 acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
327 acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
328 acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
331 for (
int i = 0; i < 4; i++)
333 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
335 acc_buffer_ptr += 16;
template <> struct QuantizedDepthwiseConvKernel<false, 2, 2>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
350 uint8x8_t filter_u8 = vdup_n_u8(0);
351 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
352 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
353 filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
354 filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4)
364 for (
int i = 0; i < 4; i++)
366 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold.
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
377 acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
378 acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
379 acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
380 acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
382 for (
int i = 0; i < 4; i++)
384 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
386 acc_buffer_ptr += 16;
389 for (; outp < num_output_pixels; outp++)
392 int32x4_t acc = vld1q_s32(acc_buffer_ptr);
394 uint8x8_t input_u8 = vdup_n_u8(0);
395 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
396 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
401 const int16x4_t input_dup2 = vzip_s16(input, input).val[0];
403 acc = vmlal_s16(acc, filter, input_dup2);
405 vst1q_s32(acc_buffer_ptr, acc);
template <> struct QuantizedDepthwiseConvKernel<false, 2, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
421 uint8x8_t filter_u8 = vdup_n_u8(0);
422 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
423 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
424 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
425 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
    const int16x4_t filter_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8)
435 for (
int i = 0; i < 4; i++)
437 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
440 uint8x8_t input_u8[2];
441 for (
int i = 0; i < 2; i++)
443 input_u8[i] = vld1_u8(input_ptr + 8 * i);
447 for (
int i = 0; i < 2; i++)
449 input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
451 for (
int i = 0; i < 2; i++)
453 input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
457 acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0]));
458 acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0]));
459 acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1]));
460 acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1]));
462 for (
int i = 0; i < 4; i++)
464 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
466 acc_buffer_ptr += 16;
469 for (; outp <= num_output_pixels - 4; outp += 4)
473 for (
int i = 0; i < 2; i++)
475 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
478 const uint8x8_t input_u8 = vld1_u8(input_ptr);
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
484 acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input));
485 acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input));
487 for (
int i = 0; i < 2; i++)
489 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
494 for (; outp <= num_output_pixels - 2; outp += 2)
497 int32x4_t acc = vld1q_s32(acc_buffer_ptr);
499 uint8x8_t input_u8 = vdup_n_u8(0);
500 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
501 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
502 input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
503 input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
509 acc = vmlal_s16(acc, filter, input);
511 vst1q_s32(acc_buffer_ptr, acc);
515 for (; outp < num_output_pixels; outp++)
518 int32x2_t acc = vld1_s32(acc_buffer_ptr);
520 uint8x8_t input_u8 = vdup_n_u8(0);
521 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
522 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
528 acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
530 vst1_s32(acc_buffer_ptr, acc);
template <> struct QuantizedDepthwiseConvKernel<false, 1, 2>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
546 uint8x8_t filter_u8 = vdup_n_u8(0);
547 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
548 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
549 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
550 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
    const int16x4_t filter_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8)
560 for (
int i = 0; i < 4; i++)
562 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
566 const uint8x8_t input_u8 = vld1_u8(input_ptr);
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
571 const int16x8x2_t input_dup2 = vzipq_s16(input, input);
573 acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
574 acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
575 acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
576 acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
578 for (
int i = 0; i < 4; i++)
580 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
582 acc_buffer_ptr += 16;
585 for (; outp < num_output_pixels; outp++)
588 int32x2_t acc = vld1_s32(acc_buffer_ptr);
      const uint32_t input = *input_ptr++ + input_offset;
594 acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input));
596 vst1_s32(acc_buffer_ptr, acc);
template <> struct QuantizedDepthwiseConvKernel<false, 1, 4>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
612 uint8x8_t filter_u8 = vdup_n_u8(0);
613 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
614 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
615 filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
616 filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8)
626 for (
int i = 0; i < 8; i++)
628 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
632 uint8x8_t input_u8 = vld1_u8(input_ptr);
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
638 acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0);
639 acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1);
640 acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2);
641 acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3);
642 acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0);
643 acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1);
644 acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2);
645 acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3);
648 for (
int i = 0; i < 8; i++)
650 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
652 acc_buffer_ptr += 32;
655 for (; outp <= num_output_pixels - 4; outp += 4)
659 for (
int i = 0; i < 4; i++)
661 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
665 uint8x8_t input_u8 = vdup_n_u8(0);
666 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
667 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
668 input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
669 input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
675 acc[0] = vmlal_lane_s16(acc[0], filter, input, 0);
676 acc[1] = vmlal_lane_s16(acc[1], filter, input, 1);
677 acc[2] = vmlal_lane_s16(acc[2], filter, input, 2);
678 acc[3] = vmlal_lane_s16(acc[3], filter, input, 3);
681 for (
int i = 0; i < 4; i++)
683 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
685 acc_buffer_ptr += 16;
688 for (; outp < num_output_pixels; outp++)
691 int32x4_t acc = vld1q_s32(acc_buffer_ptr);
      const uint32_t input = *input_ptr++ + input_offset;
697 acc = vmlal_n_s16(acc, filter, input);
699 vst1q_s32(acc_buffer_ptr, acc);
template <> struct QuantizedDepthwiseConvKernel<false, 4, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
715 uint8x8_t filter_u8 = vdup_n_u8(0);
716 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
717 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
718 filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
719 filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4)
729 for (
int i = 0; i < 4; i++)
731 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      // Load the inputs, add input_offset.
      int16x8_t input[2];
      for (int i = 0; i < 2; i++)
      {
        const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i);
        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      }
      input_ptr += 16;
743 for (
int i = 0; i < 2; i++)
745 acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i]));
746 acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i]));
749 for (
int i = 0; i < 4; i++)
751 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
753 acc_buffer_ptr += 16;
756 for (; outp < num_output_pixels; outp++)
760 acc = vld1q_s32(acc_buffer_ptr);
763 uint8x8_t input_u8 = vdup_n_u8(0);
764 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
765 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
766 input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
767 input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
772 acc = vmlal_s16(acc, filter, input);
774 vst1q_s32(acc_buffer_ptr, acc);
template <> struct QuantizedDepthwiseConvKernel<false, 4, 4>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    (void)input_ptr_increment;
    // Load the filters, add filter_offset.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++)
    {
      const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
      const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
      filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    }
    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2)
804 for (
int i = 0; i < 8; i++)
806 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
810 uint8x8_t input_u8 = vld1_u8(input_ptr);
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
816 acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), vget_low_s16(input), 0);
817 acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), vget_low_s16(input), 1);
818 acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), vget_low_s16(input), 2);
819 acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), vget_low_s16(input), 3);
820 acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), vget_high_s16(input), 0);
821 acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), vget_high_s16(input), 1);
822 acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), vget_high_s16(input), 2);
823 acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), vget_high_s16(input), 3);
825 for (
int i = 0; i < 8; i++)
827 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
829 acc_buffer_ptr += 32;
832 for (; outp < num_output_pixels; outp++)
836 for (
int i = 0; i < 4; i++)
838 acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
842 uint8x8_t input_u8 = vdup_n_u8(0);
843 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
844 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
845 input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
846 input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
852 acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
853 acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1);
854 acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2);
855 acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3);
857 for (
int i = 0; i < 4; i++)
859 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
861 acc_buffer_ptr += 16;
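// The kFixedInputDepth == 0 specializations below handle an arbitrary input
// depth, processing channels in chunks of 8 or 16 with a scalar tail loop.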
template <> struct QuantizedDepthwiseConvKernel<true, 0, 3>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
877 static const uint8_t dup3_indices_array[3][8] = {
878 {0, 0, 0, 1, 1, 1, 2, 2}, {2, 3, 3, 3, 4, 4, 4, 5}, {5, 5, 6, 6, 6, 7, 7, 7}};
879 uint8x8_t dup3_indices[3];
880 for (
int i = 0; i < 3; i++)
882 dup3_indices[i] = vld1_u8(dup3_indices_array[i]);
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      const uint8_t *local_filter_ptr = filter_ptr;
      const uint8_t *local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8)
        // Load the filters, add filter_offset.
        int16x8_t filter[3];
        uint8x8x3_t filter_u8;
897 filter_u8.val[0] = vld1_u8(local_filter_ptr);
898 filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
899 filter_u8.val[2] = vld1_u8(local_filter_ptr + 16);
900 local_filter_ptr += 24;
901 for (
int i = 0; i < 3; i++)
903 const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
904 filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
907 const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
908 local_input_ptr += 8;
910 uint8x8_t input_u8_dup3[3];
911 for (
int i = 0; i < 3; i++)
913 input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]);
915 int16x8_t input_dup3[3];
916 for (
int i = 0; i < 3; i++)
918 const int16x8_t input_s16_dup3 = vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i]));
919 input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer.
        int32x4x3_t acc[2];
        for (int i = 0; i < 2; i++)
        {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
          acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16);
        }
        // Multiply-accumulate.
        for (int j = 0; j < 3; j++)
        {
          acc[0].val[j] =
            vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]), vget_low_s16(filter[j]));
          acc[1].val[j] =
            vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]), vget_high_s16(filter[j]));
        }
938 for (
int i = 0; i < 2; i++)
940 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
941 vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
942 vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]);
944 acc_buffer_ptr += 24;
947 for (; ic < input_depth; ic++)
949 const int16_t input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 3; i++)
        {
          const int16_t filter_val = local_filter_ptr[i] + filter_offset;
          *acc_buffer_ptr++ += static_cast<int32_t>(filter_val) * input_val;
        }
955 local_filter_ptr += 3;
957 input_ptr += input_ptr_increment;
template <> struct QuantizedDepthwiseConvKernel<true, 0, 2>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      const uint8_t *local_filter_ptr = filter_ptr;
      const uint8_t *local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8)
        // Load the filters, add filter_offset.
        int16x8_t filter[2];
        uint8x8x2_t filter_u8;
982 filter_u8.val[0] = vld1_u8(local_filter_ptr);
983 filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
984 local_filter_ptr += 16;
985 for (
int i = 0; i < 2; i++)
987 const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
988 filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
991 const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
992 local_input_ptr += 8;
993 const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
995 const int16x8x2_t input_dup2 = vzipq_s16(input, input);
        // Load the accumulators from acc_buffer.
        int32x4x2_t acc[2];
        for (int i = 0; i < 2; i++)
        {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
        }
        // Multiply-accumulate.
        for (int j = 0; j < 2; j++)
        {
          acc[0].val[j] =
            vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]), vget_low_s16(input_dup2.val[j]));
          acc[1].val[j] =
            vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]), vget_high_s16(input_dup2.val[j]));
        }
1012 for (
int i = 0; i < 2; i++)
1014 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
1015 vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
1017 acc_buffer_ptr += 16;
1020 for (; ic < input_depth; ic++)
1023 const int16_t input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 2; i++)
        {
          const int16_t filter_val = local_filter_ptr[i] + filter_offset;
          *acc_buffer_ptr++ += static_cast<int32_t>(filter_val) * input_val;
        }
1029 local_filter_ptr += 2;
1031 input_ptr += input_ptr_increment;
template <> struct QuantizedDepthwiseConvKernel<true, 0, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      const uint8_t *local_filter_ptr = filter_ptr;
      const uint8_t *local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 16 input channels at a time.
      for (; ic <= input_depth - 16; ic += 16)
1054 uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0);
1055 uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1);
1056 local_filter_ptr += 16;
1057 int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1058 int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1059 filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
1060 filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
1062 uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0);
1063 uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1);
1064 local_input_ptr += 16;
1065 int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
1066 int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
1067 input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
1068 input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
1070 int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1071 int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1072 int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1073 int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
1074 acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0));
1075 acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0));
1076 acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1));
1077 acc_3 = vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1));
1079 vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1080 vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1081 vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1082 vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
1083 acc_buffer_ptr += 16;
1086 for (; ic <= input_depth - 8; ic += 8)
1089 const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr);
1090 local_filter_ptr += 8;
1091 const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
        const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
1094 const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
1095 local_input_ptr += 8;
1096 const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer.
        int32x4_t acc[2];
        for (int i = 0; i < 2; i++)
          acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
1105 acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
1106 acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
1108 for (
int i = 0; i < 2; i++)
1110 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1112 acc_buffer_ptr += 8;
1115 for (; ic < input_depth; ic++)
1117 const int16_t input_val = *local_input_ptr++ + input_offset;
1118 const int16_t filter_val = *local_filter_ptr++ + filter_offset;
        *acc_buffer_ptr++ += static_cast<int32_t>(filter_val) * input_val;
1121 input_ptr += input_ptr_increment;
template <> struct QuantizedDepthwiseConvKernel<true, 16, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8[2];
    for (int i = 0; i < 2; i++)
      filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++)
      filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
    for (int i = 0; i < 2; i++)
      filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
1150 for (
int outp = 0; outp < num_output_pixels; outp++)
1153 uint8x8_t input_u8[2];
1154 for (
int i = 0; i < 2; i++)
1156 input_u8[i] = vld1_u8(input_ptr + 8 * i);
1158 input_ptr += input_ptr_increment;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++)
        input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
      for (int i = 0; i < 2; i++)
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++)
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      // Multiply-accumulate.
      for (int i = 0; i < 2; i++)
      {
        acc[2 * i + 0] =
          vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]), vget_low_s16(filter[i]));
        acc[2 * i + 1] =
          vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]), vget_high_s16(filter[i]));
      }
1182 for (
int i = 0; i < 4; i++)
1184 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1186 acc_buffer_ptr += 16;
template <> struct QuantizedDepthwiseConvKernel<true, 8, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++)
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
1217 acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
1218 acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
1220 for (
int i = 0; i < 2; i++)
1222 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1224 acc_buffer_ptr += 8;
1225 input_ptr += input_ptr_increment;
template <> struct QuantizedDepthwiseConvKernel<true, 1, 16>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8[2];
    for (int i = 0; i < 2; i++)
      filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++)
      filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
    for (int i = 0; i < 2; i++)
      filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      // Load the inputs, add input_offset.
      uint8_t input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16_t input = static_cast<int16_t>(input_u8) + input_offset;
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++)
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
1266 for (
int i = 0; i < 2; i++)
1268 acc[2 * i + 0] = vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input);
1269 acc[2 * i + 1] = vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input);
1272 for (
int i = 0; i < 4; i++)
1274 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1276 acc_buffer_ptr += 16;
template <> struct QuantizedDepthwiseConvKernel<true, 1, 32>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
1290 uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
1291 uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
1292 uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2);
1293 uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3);
1294 int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1295 int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1296 int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2));
1297 int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3));
1298 filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
1299 filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
1300 filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset));
1301 filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      uint8_t input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16_t input = static_cast<int16_t>(input_u8) + input_offset;
1309 int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1310 int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1311 int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1312 int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
1313 int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
1314 int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5);
1315 int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6);
1316 int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7);
1318 acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
1319 acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
1320 acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
1321 acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
1322 acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input);
1323 acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input);
1324 acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input);
1325 acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input);
1327 vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1328 vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1329 vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1330 vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
1331 vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
1332 vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5);
1333 vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6);
1334 vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7);
1335 acc_buffer_ptr += 32;
template <> struct QuantizedDepthwiseConvKernel<true, 1, 20>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
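    // The 20 filter values are loaded as two full 8-byte vectors plus a third
    // load that overlaps the second by 4 bytes, so the last four filter values
    // end up in the high half of filter_x.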
1354 uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
1355 uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
1356 uint8x8_t filter_u8_x = vld1_u8(filter_ptr + 8 * 1 + 4);
1357 int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1358 int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1359 int16x8_t filter_x = vreinterpretq_s16_u16(vmovl_u8(filter_u8_x));
1360 filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
1361 filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
1362 filter_x = vaddq_s16(filter_x, vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      uint8_t input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16_t input = static_cast<int16_t>(input_u8) + input_offset;
1370 int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1371 int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1372 int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1373 int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
1374 int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
1376 acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
1377 acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
1378 acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
1379 acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
1380 acc_4 = vmlal_n_s16(acc_4, vget_high_s16(filter_x), input);
1382 vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1383 vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1384 vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1385 vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
1386 vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
1387 acc_buffer_ptr += 20;
template <> struct QuantizedDepthwiseConvKernel<true, 1, 8>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter =
      vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++)
      // Load the inputs, add input_offset.
      uint8_t input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16_t input = static_cast<int16_t>(input_u8) + input_offset;
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++)
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
1417 acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input);
1418 acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input);
1420 for (
int i = 0; i < 2; i++)
1422 vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1424 acc_buffer_ptr += 8;
template <> struct QuantizedDepthwiseConvKernel<true, 2, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
1438 uint8x8_t filter_u8 = vdup_n_u8(0);
1439 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
1440 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
1441 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
1442 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
    const int16x4_t filter_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2)
1452 int32x4_t acc = vld1q_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      uint16x4_t input_u16 = vdup_n_u16(0);
      input_u16 = vset_lane_u16((reinterpret_cast<const uint16_t *>(input_ptr))[0], input_u16, 0);
      input_ptr += input_ptr_increment;
      input_u16 = vset_lane_u16((reinterpret_cast<const uint16_t *>(input_ptr))[0], input_u16, 1);
      input_ptr += input_ptr_increment;
      const int16x4_t input_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16))));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1464 acc = vmlal_s16(acc, filter, input);
1466 vst1q_s32(acc_buffer_ptr, acc);
1467 acc_buffer_ptr += 4;
1471 for (; outp < num_output_pixels; outp++)
1474 int32x2_t acc = vld1_s32(acc_buffer_ptr);
1476 uint8x8_t input_u8 = vdup_n_u8(0);
1477 input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
1478 input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
1479 input_ptr += input_ptr_increment;
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1484 acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
1486 vst1_s32(acc_buffer_ptr, acc);
1487 acc_buffer_ptr += 2;
template <> struct QuantizedDepthwiseConvKernel<true, 4, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
    if (num_output_pixels <= 0)
      return;
1506 uint8x8_t filter_u8 = vdup_n_u8(0);
1507 filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
1508 filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
1509 filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
1510 filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
    int outp = 0;
    // Handle one output pixel at a time up to the second-to-last pixel. The
    // last pixel is handled separately because the 8-byte input load below
    // would otherwise read past the end of the input buffer.
    for (; outp < num_output_pixels - 1; outp++)
      // Load the accumulators from acc_buffer.
      int32x4_t acc;
      acc = vld1q_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += input_ptr_increment;
      const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1531 acc = vmlal_s16(acc, filter, input);
1533 vst1q_s32(acc_buffer_ptr, acc);
1534 acc_buffer_ptr += 4;
    // Handle the last output pixel separately, loading only four input bytes.
    // Load the accumulators from acc_buffer.
    int32x4_t acc;
    acc = vld1q_s32(acc_buffer_ptr);
    // Load the inputs, add input_offset.
    uint8x8_t input_u8 = vdup_n_u8(0);
    input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
    input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
    input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
    input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
    const int16x4_t input_s16 = vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
    const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1551 acc = vmlal_s16(acc, filter, input);
1553 vst1q_s32(acc_buffer_ptr, acc);
template <> struct QuantizedDepthwiseConvKernel<false, 12, 1>
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8_t *input_ptr, int16_t input_offset, int input_ptr_increment,
                  const uint8_t *filter_ptr, int16_t filter_offset, int32_t *acc_buffer_ptr)
    (void)depth_multiplier;
1566 uint8x8_t filter_u8_0 = vld1_u8(filter_ptr);
1567 uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4);
1568 int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1569 int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1570 filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset));
1571 filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset));
1572 int16x4_t filter_0 = vget_low_s16(filter_s16_0);
1573 int16x4_t filter_1 = vget_high_s16(filter_s16_0);
1574 int16x4_t filter_2 = vget_high_s16(filter_s16_1);
1577 for (
int outp = 0; outp < num_output_pixels; outp++)
1580 uint8x8_t input_u8_0 = vld1_u8(input_ptr);
1581 uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4);
1582 input_ptr += input_ptr_increment;
1583 int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
1584 int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
1585 input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
1586 input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
1589 int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1590 int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1591 int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1594 acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0);
1595 acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1);
1596 acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2);
1599 vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1600 vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1601 vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1603 acc_buffer_ptr += 12;
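// Accumulates the effect of one row of the filter on a segment of one output
// row, reading the corresponding row of the input, using the kernel
// specialization selected by the template parameters.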
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, int input_depth,
                                    int input_width, const uint8_t *input_data,
                                    int16_t input_offset, int pad_width, int depth_multiplier,
                                    int filter_width, const uint8_t *filter_data,
                                    int16_t filter_offset, int out_x_buffer_start,
                                    int out_x_buffer_end, int output_depth, int32_t *acc_buffer)
  // Sanity check parameters. This is important in particular to ensure
  // that we keep the number of template instantiations minimal, so we don't
  // increase binary size unnecessarily.
  static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
  static_assert(kFixedInputDepth || kAllowStrided, "");
  assert(stride == 1 || kAllowStrided);
  if (kFixedInputDepth)
    assert(input_depth == kFixedInputDepth);
  if (kFixedDepthMultiplier)
    assert(depth_multiplier == kFixedDepthMultiplier);
  assert(output_depth == input_depth * depth_multiplier);
  const int input_ptr_increment = stride * input_depth;
  const uint8_t *filter_base_ptr = filter_data;
  for (int filter_x = 0; filter_x < filter_width; ++filter_x)
    // For the current (filter_x, filter_y) point in the filter,
    // compute the boundaries of the corresponding output row segment.
    int out_x_loop_start_unclampled = 0;
    int out_x_loop_end_unclampled = 0;
    if (kAllowStrided)
    {
      if (stride == 2)
      {
        out_x_loop_start_unclampled = (pad_width - dilation_factor * filter_x + 1) / 2;
        out_x_loop_end_unclampled = (pad_width + input_width - dilation_factor * filter_x + 1) / 2;
      }
      else if (stride == 4)
      {
        out_x_loop_start_unclampled = (pad_width - dilation_factor * filter_x + 3) / 4;
        out_x_loop_end_unclampled = (pad_width + input_width - dilation_factor * filter_x + 3) / 4;
      }
      else
      {
        out_x_loop_start_unclampled =
          (pad_width - dilation_factor * filter_x + stride - 1) / stride;
        out_x_loop_end_unclampled =
          (pad_width + input_width - dilation_factor * filter_x + stride - 1) / stride;
      }
    }
    else
    {
      out_x_loop_start_unclampled = pad_width - dilation_factor * filter_x;
      out_x_loop_end_unclampled = pad_width + input_width - dilation_factor * filter_x;
    }
    const int out_x_loop_start = std::max(out_x_buffer_start, out_x_loop_start_unclampled);
    const int out_x_loop_end = std::min(out_x_buffer_end, out_x_loop_end_unclampled);

    int32_t *acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
    const int in_x_origin = (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
    const uint8_t *input_ptr = input_data + in_x_origin * input_depth;
    const int num_output_pixels = out_x_loop_end - out_x_loop_start;
    QuantizedDepthwiseConvKernel<kAllowStrided, kFixedInputDepth, kFixedDepthMultiplier>::Run(
      num_output_pixels, input_depth, depth_multiplier, input_ptr, input_offset,
      input_ptr_increment, filter_base_ptr, filter_offset, acc_buffer_ptr);
    filter_base_ptr += output_depth;
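// Generic, portable row-accumulation routine, used as a fallback when no
// specialized kernel matches the (stride, input depth, depth multiplier) case.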
inline void QuantizedDepthwiseConvAccumRowGeneric(int stride, int dilation_factor, int input_depth,
                                                  int input_width, const uint8_t *input_data,
                                                  int16_t input_offset, int pad_width,
                                                  int depth_multiplier, int filter_width,
                                                  const uint8_t *filter_data, int16_t filter_offset,
                                                  int out_x_buffer_start, int out_x_buffer_end,
                                                  int output_depth, int32_t *acc_buffer)
  const uint8_t *filter_base_ptr = filter_data;
  for (int filter_x = 0; filter_x < filter_width; ++filter_x)
    const int out_x_loop_start = std::max(
      out_x_buffer_start, (pad_width - dilation_factor * filter_x + stride - 1) / stride);
    const int out_x_loop_end =
      std::min(out_x_buffer_end,
               (pad_width + input_width - dilation_factor * filter_x + stride - 1) / stride);

    int32_t *acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
    const int in_x_origin = (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
    const uint8_t *input_ptr = input_data + in_x_origin * input_depth;
    const int input_ptr_increment = (stride - 1) * input_depth;
    for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++)
    {
      const uint8_t *filter_ptr = filter_base_ptr;
      for (int ic = 0; ic < input_depth; ++ic)
      {
        const int16_t input_val = *input_ptr++ + input_offset;
        for (int m = 0; m < depth_multiplier; m++)
        {
          const int16_t filter_val = *filter_ptr++ + filter_offset;
          *acc_buffer_ptr++ += static_cast<int32_t>(filter_val) * input_val;
        }
      }
      input_ptr += input_ptr_increment;
    }
    filter_base_ptr += output_depth;
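// Initializes the accumulator buffer with the bias values, with vectorized
// fast paths for output depths 1, 2, 4, 8 and 16.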
inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
                                       const int32_t *bias_data, int32_t *acc_buffer)
  int i = 0;
  if (output_depth == 1)
1731 const int32x4_t b = vdupq_n_s32(bias_data[0]);
1732 for (; i <= num_output_pixels - 16; i += 16)
1734 vst1q_s32(acc_buffer + i + 0, b);
1735 vst1q_s32(acc_buffer + i + 4, b);
1736 vst1q_s32(acc_buffer + i + 8, b);
1737 vst1q_s32(acc_buffer + i + 12, b);
1739 for (; i <= num_output_pixels - 4; i += 4)
1741 vst1q_s32(acc_buffer + i, b);
1744 else if (output_depth == 2)
1746 int32x4_t b = vdupq_n_s32(bias_data[0]);
1747 b = vsetq_lane_s32(bias_data[1], b, 1);
1748 b = vsetq_lane_s32(bias_data[1], b, 3);
1749 for (; i <= num_output_pixels - 8; i += 8)
1751 vst1q_s32(acc_buffer + 2 * i + 0, b);
1752 vst1q_s32(acc_buffer + 2 * i + 4, b);
1753 vst1q_s32(acc_buffer + 2 * i + 8, b);
1754 vst1q_s32(acc_buffer + 2 * i + 12, b);
1756 for (; i <= num_output_pixels - 2; i += 2)
1758 vst1q_s32(acc_buffer + 2 * i, b);
1761 else if (output_depth == 4)
1763 const int32x4_t b = vld1q_s32(bias_data);
1764 for (; i <= num_output_pixels - 4; i += 4)
1766 vst1q_s32(acc_buffer + 4 * i + 0, b);
1767 vst1q_s32(acc_buffer + 4 * i + 4, b);
1768 vst1q_s32(acc_buffer + 4 * i + 8, b);
1769 vst1q_s32(acc_buffer + 4 * i + 12, b);
1771 for (; i < num_output_pixels; i++)
1773 vst1q_s32(acc_buffer + 4 * i, b);
1776 else if (output_depth == 8)
1778 const int32x4_t b0 = vld1q_s32(bias_data);
1779 const int32x4_t b1 = vld1q_s32(bias_data + 4);
1780 for (; i <= num_output_pixels - 2; i += 2)
1782 vst1q_s32(acc_buffer + 8 * i + 0, b0);
1783 vst1q_s32(acc_buffer + 8 * i + 4, b1);
1784 vst1q_s32(acc_buffer + 8 * i + 8, b0);
1785 vst1q_s32(acc_buffer + 8 * i + 12, b1);
1787 for (; i < num_output_pixels; i++)
1789 vst1q_s32(acc_buffer + 8 * i + 0, b0);
1790 vst1q_s32(acc_buffer + 8 * i + 4, b1);
1793 else if (output_depth == 16)
1795 const int32x4_t b0 = vld1q_s32(bias_data);
1796 const int32x4_t b1 = vld1q_s32(bias_data + 4);
1797 const int32x4_t b2 = vld1q_s32(bias_data + 8);
1798 const int32x4_t b3 = vld1q_s32(bias_data + 12);
1799 for (; i < num_output_pixels; i++)
1801 vst1q_s32(acc_buffer + 16 * i + 0, b0);
1802 vst1q_s32(acc_buffer + 16 * i + 4, b1);
1803 vst1q_s32(acc_buffer + 16 * i + 8, b2);
1804 vst1q_s32(acc_buffer + 16 * i + 12, b3);
  for (; i < num_output_pixels; i++)
    memcpy(acc_buffer + i * output_depth, bias_data, sizeof(acc_buffer[0]) * output_depth);
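// Quantized depthwise convolution driver: for each strip of output pixels it
// seeds a stack-allocated int32 accumulator buffer with the bias, runs the
// selected row-accumulation function over the filter rows, then requantizes
// the accumulators down to uint8.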
inline void DepthwiseConvGeneral(const DepthwiseConvParams &params, const Shape &input_shape,
                                 const uint8_t *input_data, const Shape &filter_shape,
                                 const uint8_t *filter_data, const Shape &bias_shape,
                                 const int32_t *bias_data, const Shape &output_shape,
                                 uint8_t *output_data, int thread_start, int thread_end,
                                 int thread_dim)
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int input_depth = input_shape.Dims(3);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
1846 const bool shift_left = (output_shift > 0);
1847 const int32_t multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
1850 static const int kAccBufferMaxSize = 2048;
1851 int32_t acc_buffer[kAccBufferMaxSize];
1852 assert(kAccBufferMaxSize >= output_depth);
1853 const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
  [[maybe_unused]] const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
  assert(kOutputPixelsInAccBuffer * output_depth <= kAccBufferActualSize);
  assert(kAccBufferActualSize <= kAccBufferMaxSize);
  assert(kOutputPixelsInAccBuffer >= 1);
  assert(thread_dim == 0 || thread_dim == 1);

  // row_accum_func will point to the core accumulating function to be used
  // for this DepthwiseConv op.
  using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric);
  row_accum_func_t row_accum_func = nullptr;
#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, FIXED_DEPTH_MULTIPLIER) \
  if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&                                  \
      (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&                             \
      depth_multiplier == FIXED_DEPTH_MULTIPLIER)                                                 \
  {                                                                                               \
    row_accum_func =                                                                              \
      QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, FIXED_DEPTH_MULTIPLIER>;   \
  }

  if (!row_accum_func)
  {
    row_accum_func = QuantizedDepthwiseConvAccumRowGeneric;
  }
#undef TFMINI_USE_DEPTHWISECONV_KERNEL
  const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
  const int input_batch_stride = input_height_stride * input_shape.Dims(1);
  const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
  int batch_start = 0;
  int batch_end = batches;
  int row_start = 0;
  int row_end = output_height;
  int output_ptr_offset = 0;

  switch (thread_dim)
  {
    case 0:
      // Multithread along the batch axis.
      assert(thread_start >= 0);
      assert(thread_end <= batches);
      batch_start = thread_start;
      batch_end = thread_end;
      break;
    case 1:
      // Multithread along the row axis.
      assert(thread_start >= 0);
      assert(thread_end <= output_height);
      row_start = thread_start;
      row_end = thread_end;
      output_ptr_offset = row_start * output_width * output_depth;
      break;
  }

  uint8_t *output_ptr = output_data + output_ptr_offset;
  int batch_step = (output_height + row_start - row_end) * output_width * output_depth;
  for (int b = batch_start; b < batch_end; ++b)
    for (int out_y = row_start; out_y < row_end; ++out_y)
      const int in_y_origin = (out_y * stride_height) - pad_height;
      const int filter_y_start =
        std::max(0, (-in_y_origin + dilation_height_factor - 1) / dilation_height_factor);
      const int filter_y_end =
        std::min(filter_height, (input_height - in_y_origin + dilation_height_factor - 1) /
                                  dilation_height_factor);
      for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
           out_x_buffer_start += kOutputPixelsInAccBuffer)
        const int out_x_buffer_end =
          std::min(output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
        const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
        // Initialize our local accumulator with the bias values, so we don't
        // have to add them later.
        DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, acc_buffer);
        // Accumulation loop. Most of the time is spent here.
        for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y)
1978 const int in_y = in_y_origin + dilation_height_factor * filter_y;
1979 row_accum_func(stride_width, dilation_width_factor, input_depth, input_width,
1980 input_data + in_y * input_height_stride + b * input_batch_stride,
1981 input_offset, pad_width, depth_multiplier, filter_width,
1982 filter_data + filter_y * filter_height_stride, filter_offset,
1983 out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
1987 const int num_output_values = output_depth * num_output_pixels;
1990 using gemmlowp::RoundingDivideByPOT;
1991 const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1992 const int32x4_t output_activation_min_vec = vdupq_n_s32(output_activation_min);
1993 const int32x4_t output_activation_max_vec = vdupq_n_s32(output_activation_max);
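        // Requantize the int32 accumulators down to uint8: fixed-point multiply
        // by output_multiplier, rounding shift by output_shift, add the output
        // offset, then clamp to [output_activation_min, output_activation_max].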
        int i = 0;
        // Handle 16 output values at a time with NEON.
        for (; i <= num_output_values - 16; i += 16)
        {
          // Load our accumulators.
          int32x4_t acc[4];
          for (int j = 0; j < 4; j++)
            acc[j] = vld1q_s32(acc_buffer + i + 4 * j);
          if (!shift_left)
          {
            // Fixed-point multiplication.
            for (int j = 0; j < 4; j++)
              acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
            // Rounding right shift.
            for (int j = 0; j < 4; j++)
              acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
          }
          else
          {
            // Fixed-point multiplication.
            for (int j = 0; j < 4; j++)
            {
              acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two);
              acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
            }
          }
          // Add the output offset.
          for (int j = 0; j < 4; j++)
            acc[j] = vaddq_s32(acc[j], output_offset_vec);
          // Apply the activation function.
          for (int j = 0; j < 4; j++)
            acc[j] = vmaxq_s32(acc[j], output_activation_min_vec);
          for (int j = 0; j < 4; j++)
            acc[j] = vminq_s32(acc[j], output_activation_max_vec);
          // Saturating cast to uint8_t and store to destination.
          int16x4_t acc_s16[4];
          for (int j = 0; j < 4; j++)
            acc_s16[j] = vqmovn_s32(acc[j]);
          const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]);
          const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]);
          const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0);
          const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1);
          vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1));
          output_ptr += 16;
        }
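        // Handle 8 output values at a time.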
        for (; i <= num_output_values - 8; i += 8)
        {
          int32x4_t acc0 = vld1q_s32(acc_buffer + i);
          int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4);
          if (!shift_left)
          {
            // Fixed-point multiplication.
            acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
            acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
            // Rounding right shift.
            acc0 = RoundingDivideByPOT(acc0, -output_shift);
            acc1 = RoundingDivideByPOT(acc1, -output_shift);
          }
          else
          {
            // Fixed-point multiplication.
            acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
            acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
            acc1 = vmulq_n_s32(acc1, multiplier_power_of_two);
            acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
          }
          // Add the output offset.
          acc0 = vaddq_s32(acc0, output_offset_vec);
          acc1 = vaddq_s32(acc1, output_offset_vec);
          // Apply the activation function.
          acc0 = vmaxq_s32(acc0, output_activation_min_vec);
          acc1 = vmaxq_s32(acc1, output_activation_min_vec);
          acc0 = vminq_s32(acc0, output_activation_max_vec);
          acc1 = vminq_s32(acc1, output_activation_max_vec);
          // Saturating cast to uint8_t and store to destination.
          const int16x4_t acc0_s16 = vqmovn_s32(acc0);
          const int16x4_t acc1_s16 = vqmovn_s32(acc1);
          const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16);
          const uint8x8_t res_u8 = vqmovun_s16(res_s16);
          vst1_u8(output_ptr, res_u8);
          output_ptr += 8;
        }
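        // Handle 4 output values at a time.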
        for (; i <= num_output_values - 4; i += 4)
        {
          int32x4_t acc = vld1q_s32(acc_buffer + i);
          if (!shift_left)
          {
            // Fixed-point multiplication.
            acc = vqrdmulhq_n_s32(acc, output_multiplier);
            // Rounding right shift.
            acc = RoundingDivideByPOT(acc, -output_shift);
          }
          else
          {
            // Fixed-point multiplication.
            acc = vmulq_n_s32(acc, multiplier_power_of_two);
            acc = vqrdmulhq_n_s32(acc, output_multiplier);
          }
          // Add the output offset.
          acc = vaddq_s32(acc, output_offset_vec);
          // Apply the activation function.
          acc = vmaxq_s32(acc, output_activation_min_vec);
          acc = vminq_s32(acc, output_activation_max_vec);
          // Saturating cast to uint8_t and store to destination.
          const int16x4_t acc_s16 = vqmovn_s32(acc);
          const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
          const uint8x8_t res_u8 = vqmovun_s16(res_s16);
          vst1_lane_u8(output_ptr + 0, res_u8, 0);
          vst1_lane_u8(output_ptr + 1, res_u8, 1);
          vst1_lane_u8(output_ptr + 2, res_u8, 2);
          vst1_lane_u8(output_ptr + 3, res_u8, 3);
          output_ptr += 4;
        }
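        // Handle the remaining output values one by one.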
        for (; i < num_output_values; i++)
        {
          int32_t acc = acc_buffer[i];
          acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
          acc += output_offset;
          acc = std::max(acc, output_activation_min);
          acc = std::min(acc, output_activation_max);
          *output_ptr++ = static_cast<uint8_t>(acc);
        }
    output_ptr += batch_step;