#ifndef __NNFW_CKER_NEON_TENSOR_UTILS_H__
#define __NNFW_CKER_NEON_TENSOR_UTILS_H__

#if defined __linux__ && defined __aarch64__
#include <sys/auxv.h>
#endif

#define kFloatWeightsPerNeonLane 4

constexpr int kFloatValuesPerNeonVector = 4;

using int8 = std::int8_t;
using uint8 = std::uint8_t;
using int16 = std::int16_t;
using uint16 = std::uint16_t;
using int32 = std::int32_t;
using uint32 = std::uint32_t;
// Rounds `size` down to the nearest multiple of PerNeonSize. PerNeonSize must
// be a power of two.
template <int PerNeonSize> inline int RoundDownVectors(int size)
{
  return size & ~(PerNeonSize - 1);
}
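// Illustrative usage (not from the original source): rounding 10 floats down
// to whole NEON vectors of 4 gives 8, leaving 2 elements for a scalar
// postamble loop:
//   RoundDownVectors<kFloatValuesPerNeonVector>(10) == 8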
// Allocates at least `size` bytes of uninitialized storage aligned to
// `alignment`. The caller must release the memory by calling free() on the
// pointer returned through `freeing_buffer`.
void *aligned_alloc(size_t alignment, size_t size, void **freeing_buffer)
{
  *freeing_buffer = malloc(size + alignment);
  const size_t offset = ((uintptr_t)*freeing_buffer) % alignment;
  return offset == 0 ? *freeing_buffer : ((char *)*freeing_buffer + (alignment - offset));
}
// Sums the four int32 lanes of `lane` into a single scalar.
inline int32_t AccumulateNeonLane(const int32x4_t lane)
{
#ifdef __aarch64__
  return vaddvq_s32(lane);
#else
  int64x2_t pairwiseAdded = vpaddlq_s32(lane);
  return vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
#endif
}
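// Illustrative scalar equivalent of AccumulateNeonLane (hypothetical helper,
// not used by the kernels below): extracting and adding the four lanes one by
// one yields the same sum on both aarch32 and aarch64.
inline int32_t AccumulateNeonLaneReference(const int32x4_t lane)
{
  return vgetq_lane_s32(lane, 0) + vgetq_lane_s32(lane, 1) + vgetq_lane_s32(lane, 2) +
         vgetq_lane_s32(lane, 3);
}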
// Runtime detection of the ARMv8.2 dot-product (sdot) extension. On aarch64
// Linux it is reported through getauxval(AT_HWCAP).
#if defined __linux__ && defined __aarch64__
inline bool DetectDotprodByLinuxAuxvMethod()
{
  // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers;
  // it is hard-coded here so the code also builds against older headers.
  const int kLocalHwcapAsimddp = 1 << 20;
  return getauxval(AT_HWCAP) & kLocalHwcapAsimddp;
}
#endif

inline bool DetectArmNeonDotprod()
{
#if defined __linux__ && defined __aarch64__
  return DetectDotprodByLinuxAuxvMethod();
#else
  return false;
#endif
}

inline bool HasSdotInstruction()
{
  static const bool has_dotprod = DetectArmNeonDotprod();
  return has_dotprod;
}
// Interleaves the input `vectors` in groups of four batches so that the
// dot-product kernels below can consume them with one 16-byte load per batch
// group. `n_batch` must be a multiple of 4 and `m_cols` a multiple of 16.
// The returned buffer must be released via free() on `*shuffled_vectors_free`.
inline const int8_t *ShuffleVectors(const int8_t *vectors, const int n_batch, const int m_cols,
                                    void **shuffled_vectors_free)
{
  const int kWeightsPerUint32 = 4;

  int8 *shuffled_vectors = reinterpret_cast<int8 *>(
    aligned_alloc(kWeightsPerUint32, n_batch * m_cols, shuffled_vectors_free));

  for (int i = 0; i < n_batch; i += 4)
  {
    int8 *shuffled_vectors_ptr = shuffled_vectors + (i * m_cols);
    const int8 *unshuffled_vec0_ptr = reinterpret_cast<const int8 *>(vectors) + (i * m_cols);
    const int8 *unshuffled_vec1_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 1) * m_cols);
    const int8 *unshuffled_vec2_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 2) * m_cols);
    const int8 *unshuffled_vec3_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 3) * m_cols);
    const int8 *const end_vec0_ptr = unshuffled_vec1_ptr;

    while (unshuffled_vec0_ptr != end_vec0_ptr)
    {
      asm volatile(
        // Load 16 bytes from each of the four batch vectors.
        "ld1 {v0.16b}, [%[unshuffled_vec0_ptr]], #16\n"
        "ld1 {v1.16b}, [%[unshuffled_vec1_ptr]], #16\n"
        "ld1 {v2.16b}, [%[unshuffled_vec2_ptr]], #16\n"
        "ld1 {v3.16b}, [%[unshuffled_vec3_ptr]], #16\n"
        // Store them back interleaved in 4-byte chunks.
        "st4 {v0.s, v1.s, v2.s, v3.s}[0], [%[shuffled_vectors_ptr]], #16\n"
        "st4 {v0.s, v1.s, v2.s, v3.s}[1], [%[shuffled_vectors_ptr]], #16\n"
        "st4 {v0.s, v1.s, v2.s, v3.s}[2], [%[shuffled_vectors_ptr]], #16\n"
        "st4 {v0.s, v1.s, v2.s, v3.s}[3], [%[shuffled_vectors_ptr]], #16\n"
        : [unshuffled_vec0_ptr] "+r"(unshuffled_vec0_ptr),
          [unshuffled_vec1_ptr] "+r"(unshuffled_vec1_ptr),
          [unshuffled_vec2_ptr] "+r"(unshuffled_vec2_ptr),
          [unshuffled_vec3_ptr] "+r"(unshuffled_vec3_ptr),
          [shuffled_vectors_ptr] "+r"(shuffled_vectors_ptr)
        :
        : "v0", "v1", "v2", "v3", "cc", "memory");
    }
  }

  return reinterpret_cast<const int8_t *>(shuffled_vectors);
}
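// Scalar sketch of the layout ShuffleVectors produces (hypothetical helper,
// for documentation only). It assumes n_batch is a multiple of 4 and m_cols a
// multiple of 16: the four vectors of each batch group are interleaved in
// 4-byte chunks, so one 16-byte load in the dot-product kernels covers one
// chunk from each of the four batches.
inline void ShuffleVectorsReference(const int8_t *vectors, int n_batch, int m_cols,
                                    int8_t *shuffled)
{
  for (int b = 0; b < n_batch; b += 4)
  {
    int8_t *out = shuffled + b * m_cols;
    for (int c = 0; c < m_cols; c += 4)
    {
      for (int v = 0; v < 4; ++v) // which vector of the batch group
      {
        for (int k = 0; k < 4; ++k) // byte within the 4-byte chunk
        {
          *out++ = vectors[(b + v) * m_cols + c + k];
        }
      }
    }
  }
}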
// Dot-product (sdot) kernel: processes two matrix rows and four batch vectors
// per asm invocation. Requires m_cols to be a multiple of 16, m_rows a
// multiple of 2, n_batch a multiple of 4, and a CPU with the ARMv8.2
// dot-product extension. The sdot instructions are emitted as .word constants
// so the file also assembles with toolchains that predate the extension.
static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
  const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
  const float *scaling_factors, int n_batch, float *__restrict__ result)
{
  void *shuffled_vectors_free;
  const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);

  for (int row = 0; row < m_rows; row += 2)
  {
    for (int batch = 0; batch < n_batch; batch += 4)
    {
      float *result_ptr = result + (batch * m_rows) + row;
      const int8 *mat_ptr0 = matrix + (row * m_cols);
      const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
      const int8 *mat_ptr0_end = mat_ptr1;
      const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
      const float *scaling_factors_ptr = scaling_factors + batch;
      const uint64_t wide_rows = m_rows * sizeof(float);
      const int8 *mat_ptr2 = matrix + ((row + 2) * m_cols);
      const int8 *mat_ptr3 = matrix + ((row + 3) * m_cols);

      asm volatile(
        // Zero out the accumulator registers.
        "dup v0.4s, wzr\n"
        "dup v1.4s, wzr\n"
        "dup v2.4s, wzr\n"
        "dup v3.4s, wzr\n"

        "1:\n" // batch_cols_loop

        // Read 16 more bytes from a pair of matrix rows.
        "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"

        // Prefetch two rows ahead.
        "prfm pldl1strm, [%[mat_ptr2]]\n"
        "prfm pldl1strm, [%[mat_ptr3]]\n"

        // Read 64 bytes from the shuffled input vectors (covers four batches).
        "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4f8ce100 // sdot v0.4s, v8.16b, v12.4b[0]\n"
        "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4face121 // sdot v1.4s, v9.16b, v12.4b[1]\n"
        "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4f8ce940 // sdot v0.4s, v10.16b, v12.4b[2]\n"
        "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4face961 // sdot v1.4s, v11.16b, v12.4b[3]\n"

        // Update the prefetch pointers.
        "add %[mat_ptr2], %[mat_ptr2], #16\n"
        "add %[mat_ptr3], %[mat_ptr3], #16\n"

        // Re-use the loaded vector data for the second matrix row.
        "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
        ".word 0x4f8de102 // sdot v2.4s, v8.16b, v13.4b[0]\n"
        ".word 0x4fade123 // sdot v3.4s, v9.16b, v13.4b[1]\n"
        ".word 0x4f8de942 // sdot v2.4s, v10.16b, v13.4b[2]\n"
        ".word 0x4fade963 // sdot v3.4s, v11.16b, v13.4b[3]\n"

        // If we are not done with these rows, continue.
        "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
        "bne 1b\n" // batch_cols_loop

        // Done with the rows; sum the partial accumulators.
        "add v0.4s, v0.4s, v1.4s\n"
        "add v2.4s, v2.4s, v3.4s\n"

        // Convert the integer sums to float.
        "scvtf v0.4s, v0.4s\n"
        "scvtf v1.4s, v2.4s\n"

        // Fetch the scaling factors and apply them.
        "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
        "fmul v0.4s, v4.4s, v0.4s\n"
        "fmul v1.4s, v4.4s, v1.4s\n"

        // Load the previous result values for two rows (strided by m_rows
        // floats), accumulate, and store them back.
        "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
        "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
        "fadd v9.4s, v9.4s, v0.4s\n"
        "fadd v10.4s, v10.4s, v1.4s\n"
        "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
        : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
          [result_ptr] "+r"(result_ptr), [mat_ptr2] "+r"(mat_ptr2), [mat_ptr3] "+r"(mat_ptr3)
        : [mat_ptr0_end] "r"(mat_ptr0_end), [scaling_factors_ptr] "r"(scaling_factors_ptr),
          [wide_rows] "r"(wide_rows)
        : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
          "v13", "cc", "memory");
    }
  }

  free(shuffled_vectors_free);
}
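// Scalar reference of what the kernel above computes (hypothetical helper, for
// documentation only): each result entry is incremented by the scaled integer
// dot product of one matrix row with one batch vector,
//   result[batch * m_rows + row] += scaling_factors[batch] * dot(row, vector).
inline void DotprodKernelReference(const int8_t *matrix, int m_rows, int m_cols,
                                   const int8_t *vectors, const float *scaling_factors,
                                   int n_batch, float *result)
{
  for (int batch = 0; batch < n_batch; ++batch)
  {
    for (int row = 0; row < m_rows; ++row)
    {
      int32_t dotprod = 0;
      for (int col = 0; col < m_cols; ++col)
      {
        dotprod += matrix[row * m_cols + col] * vectors[batch * m_cols + col];
      }
      result[batch * m_rows + row] += scaling_factors[batch] * dotprod;
    }
  }
}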
// Variant of the kernel above that also handles per-channel scales and an
// asymmetric input zero point: the integer accumulators are corrected by
// row_sums[row] * input_offset[batch] before scaling. If `row_sums` is null,
// the row sums are accumulated on the fly inside the asm loop.
static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
  const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
  const float *scaling_factors, int n_batch, float *__restrict__ result,
  const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
{
  void *shuffled_vectors_free;
  const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);

  for (int row = 0; row < m_rows; row += 2)
  {
    const float *channel_scales_ptr = per_channel_scale + row;
    int32_t *row_sums_ptr = row_sums ? row_sums + row : nullptr;
    for (int batch = 0; batch < n_batch; batch += 4)
    {
      float *result_ptr = result + (batch * m_rows) + row;
      const int8 *mat_ptr0 = matrix + (row * m_cols);
      const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
      const int8 *mat_ptr0_end = mat_ptr1;
      const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
      const float *scaling_factors_ptr = scaling_factors + batch;
      const uint64_t wide_rows = m_rows * sizeof(float);
      const int32_t *batch_offsets_ptr = input_offset + batch;
      const int32_t is_channel_scale_nullptr = per_channel_scale == nullptr;
      const int32_t is_row_sums_nullptr = row_sums_ptr == nullptr;

      asm volatile(
        // Zero out the accumulator registers.
        "dup v0.4s, wzr\n"
        "dup v1.4s, wzr\n"
        "dup v2.4s, wzr\n"
        "dup v3.4s, wzr\n"
        // Load the zero points and scaling factors for this batch group.
        "ld1 {v7.4s}, [%[batch_offsets_ptr]]\n"
        "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
        // Zero out the row-sum accumulators.
        "dup v14.4s, wzr\n"
        "dup v15.4s, wzr\n"

        // Pre-multiply the scaling factors with the per-channel scales, if any.
        "cmp %w[is_channel_scale_nullptr], #0\n"
        "bne 1f\n"
        "ld1r {v16.4s}, [%[channel_scales_ptr]], #4\n"
        "ld1r {v17.4s}, [%[channel_scales_ptr]]\n"
        "fmul v16.4s, v16.4s, v4.4s\n"
        "fmul v17.4s, v17.4s, v4.4s\n"
        "b 2f\n"
        "1:\n"
        "mov v16.16b, v4.16b\n"
        "mov v17.16b, v4.16b\n"

        "2:\n" // batch_cols_loop
        "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
        "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4f8ce100 // sdot v0.4s, v8.16b, v12.4b[0]\n"
        "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4face121 // sdot v1.4s, v9.16b, v12.4b[1]\n"
        "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4f8ce940 // sdot v0.4s, v10.16b, v12.4b[2]\n"
        "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4face961 // sdot v1.4s, v11.16b, v12.4b[3]\n"
        "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
        ".word 0x4f8de102 // sdot v2.4s, v8.16b, v13.4b[0]\n"
        ".word 0x4fade123 // sdot v3.4s, v9.16b, v13.4b[1]\n"
        ".word 0x4f8de942 // sdot v2.4s, v10.16b, v13.4b[2]\n"
        ".word 0x4fade963 // sdot v3.4s, v11.16b, v13.4b[3]\n"

        // Accumulate row sums on the fly when they were not precomputed.
        "cmp %w[is_row_sums_nullptr], #1\n"
        "bne 3f\n"
        "saddlp v12.8h, v12.16b\n"
        "saddlp v13.8h, v13.16b\n"
        "sadalp v14.4s, v12.8h\n"
        "sadalp v15.4s, v13.8h\n"
        "3:\n"

        "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
        "bne 2b\n" // batch_cols_loop

        // Sum the partial accumulators.
        "add v0.4s, v0.4s, v1.4s\n"
        "add v2.4s, v2.4s, v3.4s\n"

        // Finalize the row sums: either reduce the on-the-fly accumulators or
        // load the precomputed values.
        "cmp %w[is_row_sums_nullptr], #1\n"
        "bne 4f\n"
        "addv s14, v14.4s\n"
        "addv s15, v15.4s\n"
        "dup v14.4s, v14.s[0]\n"
        "dup v15.4s, v15.s[0]\n"
        "b 5f\n"
        "4:\n"
        "ld1r {v14.4s}, [%[row_sums_ptr]], #4\n"
        "ld1r {v15.4s}, [%[row_sums_ptr]]\n"
        "5:\n"

        // Subtract the zero-point contribution, convert to float and scale.
        "mul v14.4s, v14.4s, v7.4s\n"
        "mul v15.4s, v15.4s, v7.4s\n"
        "sub v0.4s, v0.4s, v14.4s\n"
        "sub v2.4s, v2.4s, v15.4s\n"
        "scvtf v0.4s, v0.4s\n"
        "scvtf v1.4s, v2.4s\n"
        "fmul v0.4s, v16.4s, v0.4s\n"
        "fmul v1.4s, v17.4s, v1.4s\n"

        // Accumulate into the previous result values for two rows (strided by
        // m_rows floats) and store them back.
        "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
        "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
        "fadd v9.4s, v9.4s, v0.4s\n"
        "fadd v10.4s, v10.4s, v1.4s\n"
        "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
        : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
          [result_ptr] "+r"(result_ptr), [row_sums_ptr] "+r"(row_sums_ptr)
        : [mat_ptr0_end] "r"(mat_ptr0_end), [scaling_factors_ptr] "r"(scaling_factors_ptr),
          [wide_rows] "r"(wide_rows), [channel_scales_ptr] "r"(channel_scales_ptr),
          [batch_offsets_ptr] "r"(batch_offsets_ptr),
          [is_channel_scale_nullptr] "r"(is_channel_scale_nullptr),
          [is_row_sums_nullptr] "r"(is_row_sums_nullptr)
        : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
          "v13", "v14", "v15", "v16", "v17", "w0", "w1", "cc", "memory");
    }
  }

  free(shuffled_vectors_free);
}
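// Scalar form of the extra dequantization work done by the variant above
// (hypothetical helper, for documentation only): the zero-point contribution
// row_sum * input_offset is removed from the integer accumulator before the
// combined scale is applied.
inline float DequantizeAccumulatorReference(int32_t acc, int32_t row_sum, int32_t input_offset,
                                            float scaling_factor, float per_channel_scale)
{
  return (acc - row_sum * input_offset) * (scaling_factor * per_channel_scale);
}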
// Handles batch counts that are not a multiple of 4 by padding the vectors,
// scaling factors, input offsets and result buffer up to the next multiple of
// 4, running the dot-product kernel on the padded copies, and copying the
// valid part of the result back.
inline void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
  const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
  const float *scaling_factors, int n_batch, float *__restrict__ result,
  const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
{
  const int kWeightsPerUint32 = 4;

  // Round the batch count up to the next multiple of 4.
  int batch_round_up = n_batch;
  if (n_batch % 4 != 0)
  {
    batch_round_up += (4 - n_batch % 4);
  }
  assert(n_batch <= batch_round_up);

  void *padded_vectors_free;
  const int padded_vectors_size = batch_round_up * m_cols;
  int8_t *padded_vectors = reinterpret_cast<int8_t *>(
    aligned_alloc(kWeightsPerUint32, padded_vectors_size, &padded_vectors_free));
  memset(padded_vectors, 0, padded_vectors_size);

  void *padded_result_free;
  const int result_size = n_batch * m_rows * sizeof(float);
  const int padded_result_size = batch_round_up * m_rows * sizeof(float);
  float *padded_result = reinterpret_cast<float *>(
    aligned_alloc(kWeightsPerUint32, padded_result_size, &padded_result_free));
  memcpy(padded_result, result, result_size);
  memset(reinterpret_cast<char *>(padded_result) + result_size, 0,
         padded_result_size - result_size);

  // Copy the unpadded input vectors into the padded buffer.
  assert(n_batch * m_cols <= padded_vectors_size);
  memcpy(padded_vectors, vectors, n_batch * m_cols);

  void *padded_scaling_factors_free;
  const int padded_scaling_factors_size = batch_round_up * sizeof(float);
  float *padded_scaling_factors = reinterpret_cast<float *>(
    aligned_alloc(kWeightsPerUint32, padded_scaling_factors_size, &padded_scaling_factors_free));
  assert(static_cast<int>(n_batch * sizeof(float)) <= padded_scaling_factors_size);
  assert(static_cast<int>(batch_round_up * sizeof(float)) <= padded_scaling_factors_size);
  memset(padded_scaling_factors, 0, batch_round_up * sizeof(float));
  memcpy(padded_scaling_factors, scaling_factors, n_batch * sizeof(float));

  if (input_offset != nullptr)
  {
    void *padded_input_offset_free;
    const int padded_input_offset_size = batch_round_up * sizeof(int32_t);
    int32_t *padded_input_offset = reinterpret_cast<int32_t *>(
      aligned_alloc(kWeightsPerUint32, padded_input_offset_size, &padded_input_offset_free));
    assert(static_cast<int>(n_batch * sizeof(int32_t)) <= padded_input_offset_size);
    assert(static_cast<int>(batch_round_up * sizeof(int32_t)) <= padded_input_offset_size);
    memset(padded_input_offset, 0, batch_round_up * sizeof(int32_t));
    memcpy(padded_input_offset, input_offset, n_batch * sizeof(int32_t));

    // Call the main kernel with the padded buffers.
    DotprodMatrixBatchFourVectorMultiplyAccumulate(
      matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors, batch_round_up,
      padded_result, per_channel_scale, padded_input_offset, row_sums);

    free(padded_input_offset_free);
  }
  else
  {
    // Call the main kernel with the padded buffers.
    DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, padded_vectors,
                                                   padded_scaling_factors, batch_round_up,
                                                   padded_result);
  }
  memcpy(result, padded_result, result_size);

  free(padded_result_free);
  free(padded_vectors_free);
  free(padded_scaling_factors_free);
}
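// Illustrative form of the round-up used above (hypothetical helper, for
// documentation only): with n_batch == 6 the kernel runs on 8 padded vectors;
// the two extra vectors and their scaling factors are zero, so they do not
// change the valid part of `result`, which is copied back afterwards.
inline int RoundUpToMultipleOfFourReference(int n_batch)
{
  return (n_batch % 4 == 0) ? n_batch : n_batch + (4 - n_batch % 4);
}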
inline void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
  const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
  const float *scaling_factors, int n_batch, float *__restrict__ result)
{
  DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
    matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
    /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr, /*row_sums=*/nullptr);
}
// Clamps every element of `vector` to [-clipping_value, clipping_value].
inline void NeonCwiseClipping(float *vector, const int v_size, const float clipping_value)
{
  const float32x4_t clipping_value_f32x4 = vmovq_n_f32(clipping_value);
  const float32x4_t neg_clipping_value_f32x4 = vmovq_n_f32(-clipping_value);

  int i = 0;
  for (; i <= v_size - kFloatValuesPerNeonVector; i += kFloatValuesPerNeonVector)
  {
    // Load from memory to vector.
    float32x4_t v_f32x4 = vld1q_f32(vector + i);
    // Clip between clipping_value and -clipping_value.
    v_f32x4 = vminq_f32(clipping_value_f32x4, v_f32x4);
    v_f32x4 = vmaxq_f32(neg_clipping_value_f32x4, v_f32x4);
    // Save to output.
    vst1q_f32(vector + i, v_f32x4);
  }
  // Postamble loop for the remaining elements.
  for (; i < v_size; i++)
  {
    vector[i] = std::max(std::min(clipping_value, vector[i]), -clipping_value);
  }
}
// Returns true if every element of `vector` is zero.
inline bool NeonIsZeroVector(const float *vector, int v_size)
{
  // If v_size is not divisible by kFloatWeightsPerNeonLane, the tail is
  // processed sequentially starting at postamble_start.
  const int postamble_start = v_size - (v_size & (kFloatWeightsPerNeonLane - 1));

  const float32x4_t zero_x4_float = vmovq_n_f32(0.0f);
  for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane)
  {
    const float32x4_t i_x4_float = vld1q_f32(vector + v);
    uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float);
    if (vgetq_lane_u32(cmp_result, 0) == 0)
      return false;
    if (vgetq_lane_u32(cmp_result, 1) == 0)
      return false;
    if (vgetq_lane_u32(cmp_result, 2) == 0)
      return false;
    if (vgetq_lane_u32(cmp_result, 3) == 0)
      return false;
  }

  // Postamble loop.
  for (int v = postamble_start; v < v_size; ++v)
  {
    if (vector[v] != 0.0)
      return false;
  }
  return true;
}
// Runs an int8 GEMM (weights x activations -> int32 scratch) through ruy.
inline void NeonCpuBackendGemm(const int8_t *input, const int32_t *bias,
                               const int8_t *input_to_gate_weights, int32_t n_batch,
                               int32_t n_input, int32_t n_output, int32_t, int32_t *scratch,
                               ruy::Context *ruy_context)
{
  MatrixParams<int8_t> lhs_params;
  lhs_params.order = Order::kRowMajor;
  lhs_params.rows = n_output;
  lhs_params.cols = n_input;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = Order::kColMajor;
  rhs_params.rows = n_input;
  rhs_params.cols = n_batch;

  MatrixParams<int32_t> dst_params;
  dst_params.order = Order::kColMajor;
  dst_params.rows = n_output;
  dst_params.cols = n_batch;

  GemmParams<int32_t, int32_t> gemm_params;
  if (bias)
  {
    gemm_params.bias = bias;
  }

  // Wrap the raw pointers in ruy matrices and run the multiplication.
  ruy::Matrix<int8_t> ruy_lhs;
  ruy::Matrix<int8_t> ruy_rhs;
  ruy::Matrix<int32_t> ruy_dst;
  MakeRuyMatrix(lhs_params, input_to_gate_weights, &ruy_lhs);
  MakeRuyMatrix(rhs_params, input, &ruy_rhs);
  MakeRuyMatrix(dst_params, scratch, &ruy_dst);

  ruy::MulParams<int32_t, int32_t> ruy_mul_params;
  MakeRuyMulParams(gemm_params, &ruy_mul_params);

  ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, ruy_context, &ruy_dst);
}
// Computes result[i] = 1.0f - vector[i] for every element.
inline void NeonSub1Vector(const float *vector, int v_size, float *result)
{
  // If v_size is not divisible by the vector size, the tail is processed
  // sequentially starting at postamble_start.
  const int postamble_start = RoundDownVectors<kFloatValuesPerNeonVector>(v_size);

  float32x4_t one_f32x4 = vmovq_n_f32(1.0);
  int v = 0;
  for (; v < postamble_start; v += kFloatValuesPerNeonVector)
  {
    // Load 4 float values from the input and subtract them from 1.
    float32x4_t v_f32x4 = vld1q_f32(vector + v);
    float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4);
    // Save to output.
    vst1q_f32(result + v, result_f32x4);
  }
  for (; v < v_size; v++)
  {
    result[v] = 1.0f - vector[v];
  }
}
// Symmetrically quantizes `values` into int8 in [-127, 127] and returns the
// observed min/max and the scaling factor.
inline void NeonSymmetricQuantizeFloats(const float *values, const int size,
                                        int8_t *quantized_values, float *min, float *max,
                                        float *scaling_factor)
{
  auto minmax = std::minmax_element(values, values + size);
  *min = *minmax.first;
  *max = *minmax.second;
  const int kScale = 127;
  const float range = std::max(std::abs(*min), std::abs(*max));
  if (range == 0)
  {
    memset(quantized_values, 0, size * sizeof(int8_t));
    *scaling_factor = 1;
    return;
  }
  *scaling_factor = range / kScale;
  const float scaling_factor_inv = kScale / range;

  const int postamble_start = size - (size & (2 * kFloatWeightsPerNeonLane - 1));

  // Vectorized constants.
  const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
  const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
  const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
  const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
  const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);

  for (int i = 0; i < postamble_start; i += 2 * kFloatWeightsPerNeonLane)
  {
    // Implements the vectorized version of
    //   quantized_value = round(scaling_factor_inv * values[i]);
    // Since a vectorized round intrinsic is not available on all targets,
    // rounding is done by adding +0.5 for non-negative products and -0.5 for
    // negative ones before truncation.
    float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
    float32x4_t value1_f32x4 = vld1q_f32(&values[i + kFloatWeightsPerNeonLane]);
    float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
    float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);

    int32x4_t cmp_with_zero0_ui32x4 = (int32x4_t)vcltq_f32(mul0_f32x4, zero_f32x4);
    int32x4_t cmp_with_zero1_ui32x4 = (int32x4_t)vcltq_f32(mul1_f32x4, zero_f32x4);

    float32x4_t cmp_with_zero0_f32x4 = vcvtq_f32_s32(cmp_with_zero0_ui32x4);
    float32x4_t cmp_with_zero1_f32x4 = vcvtq_f32_s32(cmp_with_zero1_ui32x4);
    cmp_with_zero0_f32x4 = vaddq_f32(cmp_with_zero0_f32x4, point5_f32x4);
    cmp_with_zero1_f32x4 = vaddq_f32(cmp_with_zero1_f32x4, point5_f32x4);

    mul0_f32x4 = vaddq_f32(mul0_f32x4, cmp_with_zero0_f32x4);
    mul1_f32x4 = vaddq_f32(mul1_f32x4, cmp_with_zero1_f32x4);

    int32x4_t f2i0_i32x4 = vcvtq_s32_f32(mul0_f32x4);
    int32x4_t f2i1_i32x4 = vcvtq_s32_f32(mul1_f32x4);

    // Implements the vectorized version of
    //   quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
    int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4);
    int32x4_t max1_i32x4 = vmaxq_s32(f2i1_i32x4, neg_scale_i32x4);
    int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
    int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);

    int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
    int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);

    int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
    int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
    vst1_s8(&quantized_values[i], min_s8x8);
  }

  // Postamble loop for the remaining elements.
  for (int i = postamble_start; i < size; ++i)
  {
    const int32_t quantized_value =
      static_cast<int32_t>(std::round(scaling_factor_inv * values[i]));
    quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
  }
}
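// Scalar reference of the symmetric quantization above (hypothetical helper,
// for documentation only), assuming <algorithm> and <cmath> are available as
// they are for the NEON version: values are mapped to [-127, 127] with
// scaling_factor = max(|min|, |max|) / 127 and q = round(value / scaling_factor).
inline void SymmetricQuantizeFloatsReference(const float *values, int size,
                                             int8_t *quantized_values, float *scaling_factor)
{
  float range = 0.f;
  for (int i = 0; i < size; ++i)
    range = std::max(range, std::abs(values[i]));
  if (range == 0.f)
  {
    std::fill(quantized_values, quantized_values + size, 0);
    *scaling_factor = 1.f;
    return;
  }
  *scaling_factor = range / 127.f;
  const float inv = 127.f / range;
  for (int i = 0; i < size; ++i)
  {
    const int32_t q = static_cast<int32_t>(std::round(values[i] * inv));
    quantized_values[i] = static_cast<int8_t>(std::min(127, std::max(-127, q)));
  }
}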
inline void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
                                                    const int m_rows, const int m_cols,
                                                    const int8_t *__restrict__ vectors,
                                                    const float *scaling_factors, int n_batch,
                                                    float *__restrict__ result, int result_stride)
{
  if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 && m_rows >= n_batch)
  {
    if (n_batch % 4 == 0 && result_stride == 1)
    {
      // Benchmarks suggest that it's always better to use the batch code
      // when we can, even on small matrices.
      DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
                                                     scaling_factors, n_batch, result);
      return;
    }
    else if (result_stride == 1 && n_batch >= 2 && m_rows * m_cols >= 128 * 128)
    {
      DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
                                                           scaling_factors, n_batch, result);
      return;
    }
  }

  static const int kWeightsPerUint32 = 4;
  static const int kWeightsPerNeonLane = 16;
  // Assuming *matrix is kWeightsPerUint32-byte aligned, every row of the
  // matrix is also kWeightsPerUint32-byte aligned as long as m_cols is a
  // multiple of kWeightsPerUint32. If not, each row is copied into an aligned
  // scratch buffer before the vectorized loads.
  bool unaligned = false;
  int8_t *aligned_row = nullptr;
  void *aligned_row_free = nullptr;
  if ((m_cols & (kWeightsPerUint32 - 1)) != 0)
  {
    unaligned = true;
    aligned_row = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, &aligned_row_free);
  }
  void *aligned_vec_free = nullptr;
  int8_t *aligned_vec = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, &aligned_vec_free);

  // If m_cols is not a multiple of kWeightsPerNeonLane, the tail is handled by
  // a half-register block (between postamble_half_start and postamble_start)
  // followed by a scalar loop.
  const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
  const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);

  for (int batch = 0; batch < n_batch; ++batch)
  {
    const float batch_scaling_factor = scaling_factors[batch];
    // Copy the vector data to an aligned buffer.
    memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8_t) * m_cols);
    // Compute the dot product for every row.
    for (int row = 0; row < m_rows; ++row, result += result_stride)
    {
      // Get the address of the first element of the row.
      int8_t *row_ptr = (int8_t *)matrix + row * m_cols;
      if (unaligned)
      {
        memcpy(aligned_row, row_ptr, sizeof(int8_t) * m_cols);
        row_ptr = aligned_row;
      }

      // Initialize the dot product sum for the row to 0.
      int32x4_t dotprod_32x4 = vmovq_n_s32(0);

      // Prefetch the row to cache (read access, high temporal locality).
      __builtin_prefetch(row_ptr, 0, 3);

      // For every block of 16 int8 elements.
      int col = 0;
      for (; col < postamble_half_start; col += kWeightsPerNeonLane)
      {
        // Both pointers are aligned, so full 16-byte NEON loads can be used.
        assert(((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
        const int8x16_t s1_8x16 = vld1q_s8((const int8_t *)(aligned_vec + col));
        const int8x16_t s2_8x16 = vld1q_s8((const int8_t *)(row_ptr + col));
        // Multiply the low 8 elements, then multiply-accumulate the high 8.
        int16x8_t prod_16x8 = vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
        prod_16x8 = vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
        // Pairwise add the 16-bit products into the 32-bit accumulators.
        dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
      }
      // Half-register (8-element) block, if any.
      if (col < postamble_start)
      {
        assert(((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
        const int8x8_t s1_8x8 = vld1_s8((const int8_t *)(aligned_vec + col));
        const int8x8_t s2_8x8 = vld1_s8((const int8_t *)(row_ptr + col));
        const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
        dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
        col += (kWeightsPerNeonLane >> 1);
      }
      // Add the 4 intermediate sums to get the final dot-product value.
      int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
      // Scalar postamble loop for any remaining elements.
      for (; col < m_cols; ++col)
      {
        dotprod += row_ptr[col] * aligned_vec[col];
      }

      *result += dotprod * batch_scaling_factor;
    } // for row
  }   // for batch

  if (unaligned)
  {
    free(aligned_row_free);
  }
  free(aligned_vec_free);
}
inline void NeonMatrixBatchVectorMultiplyAccumulate(const float *matrix, int m_rows, int m_cols,
                                                    const float *vector, int n_batch,
                                                    float *result, int result_stride)
{
  // If m_cols is not divisible by kFloatWeightsPerNeonLane, the tail is
  // processed sequentially starting at postamble_start.
  const int postamble_start = m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));

  for (int b = 0; b < n_batch; b++)
  {
    float *result_in_batch = result + b * m_rows * result_stride;
    const float *vector_in_batch = vector + b * m_cols;
    const float *matrix_row = matrix;

    // Main matrix-by-vector multiplication loop.
    for (int r = 0; r < m_rows; r++)
    {
      float32x4_t acc_32x4 = vmovq_n_f32(0.0);
      for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane)
      {
        // Load 4 float values from the vector and the matrix row.
        float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
        float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
        // Multiply and add to the accumulator.
        acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
      }
      // Add the 4 intermediate sums to get the final dot-product value.
      *result_in_batch += (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
                           vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
      for (int c = postamble_start; c < m_cols; c++)
      {
        *result_in_batch += matrix_row[c] * vector_in_batch[c];
      }
      matrix_row += m_cols;
      result_in_batch += result_stride;
    }
  }
}
inline void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
                                                    const int m_rows, const int m_cols,
                                                    const int8_t *__restrict__ vectors,
                                                    const float *scaling_factors, int n_batch,
                                                    int32_t *scratch, float *__restrict__ result,
                                                    int result_stride, ruy::Context *ruy_context)
{
  if (m_rows % 4 == 0 && result_stride == 1)
  {
    const int32_t *bias = static_cast<const int32_t *>(nullptr);
    NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0, scratch, ruy_context);

    // Multiply the int32 scratch values by the float scaling factors and
    // accumulate into result.
    const int total_size = n_batch * m_rows;
    int i = 0;
    for (; i <= total_size - 8; i += 8, result += 8 * result_stride)
    {
      const float batch_scaling_factor0 = scaling_factors[i / m_rows];
      const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
      const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
      const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
      const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
      const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
      const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
      const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
      const float32x4_t result0 = vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
      const float32x4_t result1 =
        vmlaq_f32(vld1q_f32(result + 4 * result_stride), float_val1, scaling_factor1);
      vst1q_f32(result, result0);
      vst1q_f32(result + 4 * result_stride, result1);
    }
    scratch += i;
    for (; i < total_size; i++, result += result_stride)
    {
      const float batch_scaling_factor = scaling_factors[i / m_rows];
      int32_t x = *(scratch++);
      *result += x * batch_scaling_factor;
    }
    return;
  }
  NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, scaling_factors,
                                          n_batch, result, result_stride);
}