38 const int8_t *filter_data,
const int32_t *bias_data,
41 cmsis_nn_conv_params conv_params;
45 assert(conv_params.dilation.h == 1);
46 assert(conv_params.dilation.w == 1);
52 conv_params.padding.h = params.
pad_h;
53 conv_params.padding.w = params.
pad_w;
57 cmsis_nn_per_channel_quant_params quant_params;
59 quant_params.shift =
const_cast<int32_t *
>(
62 assert(conv_params.activation.min <= conv_params.activation.max);
63 const int batch_size = input_shape.
dims(0);
64 const int input_depth = input_shape.
dims(3);
65 const int output_depth = filter_shape.
dims(0);
67 cmsis_nn_dims input_dims;
68 input_dims.n = batch_size;
69 input_dims.h = input_shape.
dims(1);
70 input_dims.w = input_shape.
dims(2);
71 input_dims.c = input_depth;
73 cmsis_nn_dims filter_dims;
74 filter_dims.n = output_depth;
75 filter_dims.h = filter_shape.
dims(1);
76 filter_dims.w = filter_shape.
dims(2);
77 filter_dims.c = input_depth;
79 cmsis_nn_dims bias_dims;
83 bias_dims.c = output_depth;
85 cmsis_nn_dims output_dims;
86 output_dims.n = batch_size;
89 output_dims.c = output_depth;
92 arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
94 auto buffer = std::make_unique<int8_t[]>(buf_size);
95 assert(buffer !=
nullptr);
98 ctx.buf = buffer.get();
101 auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
102 &filter_dims, filter_data, &bias_dims, bias_data, &output_dims,
105 assert(res == ARM_CMSIS_NN_SUCCESS);
106 if (res != ARM_CMSIS_NN_SUCCESS)
OMStatus ConvPerChannel(const core::ConvQuant ¶ms, const core::OMRuntimeShape &input_shape, const int8_t *input_data, const core::OMRuntimeShape &filter_shape, const int8_t *filter_data, const int32_t *bias_data, const core::OMRuntimeShape &output_shape, int8_t *output_data)