17#ifndef LUCI_INTERPRETER_PAL_CONV2D_H
18#define LUCI_INTERPRETER_PAL_CONV2D_H
20#include "PALConv2DCommon.h"
22#include <arm_nnfunctions.h>
27static inline void QuantizedConvPerChannel(
const ConvParams ¶ms,
const int32_t *input_shape,
28 const int8_t *input_data,
const int32_t *filter_shape,
29 const int8_t *filter_data,
const int32_t *bias_data,
32 cmsis_nn_conv_params conv_params;
33 conv_params.dilation.h = params.dilation_height_factor;
34 conv_params.dilation.w = params.dilation_width_factor;
36 assert(conv_params.dilation.h == 1);
37 assert(conv_params.dilation.w == 1);
39 conv_params.input_offset = params.input_offset;
40 conv_params.output_offset = params.output_offset;
41 conv_params.stride.h = params.stride_height;
42 conv_params.stride.w = params.stride_width;
43 conv_params.padding.h = params.padding_values.height;
44 conv_params.padding.w = params.padding_values.width;
45 conv_params.activation.min = params.quantized_activation_min;
46 conv_params.activation.max = params.quantized_activation_max;
48 cmsis_nn_per_channel_quant_params quant_params;
49 quant_params.multiplier =
const_cast<int32_t *
>(params.per_channel_output_multiplier.data());
50 quant_params.shift =
const_cast<int32_t *
>(
51 reinterpret_cast<const int32_t *
>(params.per_channel_output_shift.data()));
53 assert(conv_params.activation.min <= conv_params.activation.max);
54 const int batch_size = input_shape[0];
55 const int input_depth = input_shape[3];
56 const int output_depth = filter_shape[0];
58 cmsis_nn_dims input_dims;
59 input_dims.n = batch_size;
60 input_dims.h = input_shape[1];
61 input_dims.w = input_shape[2];
62 input_dims.c = input_depth;
64 cmsis_nn_dims filter_dims;
65 filter_dims.n = output_depth;
66 filter_dims.h = filter_shape[1];
67 filter_dims.w = filter_shape[2];
68 filter_dims.c = input_depth;
70 cmsis_nn_dims bias_dims;
74 bias_dims.c = output_depth;
76 cmsis_nn_dims output_dims;
77 output_dims.n = batch_size;
80 output_dims.c = output_depth;
83 arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
85 auto buffer = std::make_unique<int8_t[]>(buf_size);
86 assert(buffer !=
nullptr);
89 ctx.buf = buffer.get();
92 auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
93 &filter_dims, filter_data, &bias_dims, bias_data, &output_dims,
96 assert(res == ARM_CMSIS_NN_SUCCESS);
const luci_interpreter::RuntimeShape output_shape