17#ifndef LUCI_INTERPRETER_PAL_CONV2D_H
18#define LUCI_INTERPRETER_PAL_CONV2D_H
#include <cassert>

#include <tensorflow/lite/kernels/internal/reference/conv.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
#include <arm_nn_types.h>
#include <arm_nnfunctions.h>
27static inline void Conv(
const tflite::ConvParams ¶ms,
const tflite::RuntimeShape &input_shape,
28 const float *input_data,
const tflite::RuntimeShape &filter_shape,
29 const float *filter_data,
const tflite::RuntimeShape &bias_shape,
30 const float *bias_data,
const tflite::RuntimeShape &
output_shape,
31 float *output_data,
const tflite::RuntimeShape &scratchpad_shape,
32 float *scratchpad_data)
34 (void)scratchpad_shape;
35 (void)scratchpad_data;
36 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
38 tflite::RuntimeShape(),
nullptr);
41static inline void Conv(
const tflite::ConvParams ¶ms,
const tflite::RuntimeShape &input_shape,
42 const uint8 *input_data,
const tflite::RuntimeShape &filter_shape,
43 const uint8 *filter_data,
const tflite::RuntimeShape &bias_shape,
45 uint8 *output_data,
const tflite::RuntimeShape &scratchpad_shape,
46 uint8 *scratchpad_data)
48 (void)scratchpad_shape;
49 (void)scratchpad_data;
50 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
51 bias_shape, bias_data,
output_shape, output_data, scratchpad_shape,
52 scratchpad_data,
nullptr);
55static inline void ConvPerChannel(
const tflite::ConvParams ¶ms,
const int32_t *mult,
56 const int32_t *shifts,
const tflite::RuntimeShape &input_shape,
57 const int8 *input_data,
const tflite::RuntimeShape &filter_shape,
58 const int8 *filter_data,
const tflite::RuntimeShape &bias_shape,
60 int8 *output_data,
const tflite::RuntimeShape &scratchpad_shape,
61 int8 *scratchpad_data)
65 cmsis_nn_conv_params conv_params;
66 conv_params.dilation.h = params.dilation_height_factor;
67 conv_params.dilation.w = params.dilation_width_factor;
69 assert(conv_params.dilation.h == 1);
70 assert(conv_params.dilation.w == 1);
72 conv_params.input_offset = params.input_offset;
73 conv_params.output_offset = params.output_offset;
74 conv_params.stride.h = params.stride_height;
75 conv_params.stride.w = params.stride_width;
76 conv_params.padding.h = params.padding_values.height;
77 conv_params.padding.w = params.padding_values.width;
78 conv_params.activation.min = params.quantized_activation_min;
79 conv_params.activation.max = params.quantized_activation_max;
81 cmsis_nn_per_channel_quant_params quant_params;
82 quant_params.multiplier =
const_cast<int32_t *
>(
mult);
83 quant_params.shift =
const_cast<int32_t *
>(shifts);
85 assert(conv_params.activation.min <= conv_params.activation.max);
86 assert(input_shape.DimensionsCount() == 4);
87 assert(filter_shape.DimensionsCount() == 4);
89 const int batch_size = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
90 const int input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
91 const int output_depth = tflite::MatchingDim(filter_shape, 0,
output_shape, 3);
94 assert(bias_shape.FlatSize() == output_depth);
97 cmsis_nn_dims input_dims;
98 input_dims.n = batch_size;
99 input_dims.h = input_shape.Dims(1);
100 input_dims.w = input_shape.Dims(2);
101 input_dims.c = input_depth;
103 cmsis_nn_dims filter_dims;
104 filter_dims.n = output_depth;
105 filter_dims.h = filter_shape.Dims(1);
106 filter_dims.w = filter_shape.Dims(2);
107 filter_dims.c = input_depth;
109 cmsis_nn_dims bias_dims;
113 bias_dims.c = output_depth;
115 cmsis_nn_dims output_dims;
116 output_dims.n = batch_size;
119 output_dims.c = output_depth;
121 cmsis_nn_context ctx;
122 ctx.buf = scratchpad_data;
123 ctx.size = scratchpad_shape.Dims(0);
125 auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
126 &filter_dims, filter_data, &bias_dims, bias_data,
127 &output_dims, output_data);
128 assert(res == ARM_MATH_SUCCESS);
132 tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
133 filter_shape, filter_data, bias_shape, bias_data,
140 const tflite::ConvParams ¶ms,
141 const tflite::RuntimeShape &input_shape,
142 const tflite::RuntimeShape &filter_shape,
145 cmsis_nn_conv_params conv_params;
146 conv_params.dilation.h = params.dilation_height_factor;
147 conv_params.dilation.w = params.dilation_width_factor;
149 if (input_data_type == loco::DataType::S8 && conv_params.dilation.h == 1 &&
150 conv_params.dilation.w == 1)
152 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
153 const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
154 const int32_t output_depth = tflite::MatchingDim(filter_shape, 0,
output_shape, 3);
155 const int32_t filter_height = filter_shape.Dims(1);
156 const int32_t filter_width = filter_shape.Dims(2);
160 conv_params.input_offset = params.input_offset;
161 conv_params.output_offset = params.output_offset;
162 conv_params.stride.h = params.stride_height;
163 conv_params.stride.w = params.stride_width;
164 conv_params.padding.h = params.padding_values.height;
165 conv_params.padding.w = params.padding_values.width;
167 cmsis_nn_dims input_dims;
168 input_dims.n = batches;
169 input_dims.h = input_shape.Dims(1);
170 input_dims.w = input_shape.Dims(2);
171 input_dims.c = input_depth;
173 cmsis_nn_dims filter_dims;
174 filter_dims.n = output_depth;
175 filter_dims.h = filter_height;
176 filter_dims.w = filter_width;
177 filter_dims.c = input_depth;
179 cmsis_nn_dims output_dims;
180 output_dims.n = batches;
181 output_dims.h = output_height;
182 output_dims.w = output_width;
183 output_dims.c = output_depth;
185 const int32_t buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
186 &filter_dims, &output_dims);
189 scratchpad->
resize(scratchpad_shape);
void set_allocatable(bool value)
void resize(const Shape &new_shape)
const luci_interpreter::RuntimeShape output_shape
DataType
"scalar" value type
void Conv(const ConvParams ¶ms, const Shape &input_shape, const uint8_t *input_data, const Shape &filter_shape, const uint8_t *filter_data, const Shape &bias_shape, const int32_t *bias_data, const Shape &output_shape, uint8_t *output_data, const Shape &im2col_shape, uint8_t *im2col_data)