51 const tflite::RuntimeShape &input_shape,
const int8_t *input_data,
52 const tflite::RuntimeShape &filter_shape,
const int8_t *filter_data,
53 const tflite::RuntimeShape &bias_shape,
const int32_t *bias_data,
54 const tflite::RuntimeShape &
output_shape, int8_t *output_data)
61 const int filter_dim_count = filter_shape.DimensionsCount();
62 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
64 cmsis_nn_fc_params fc_params;
65 fc_params.input_offset = params.input_offset;
66 fc_params.output_offset = params.output_offset;
67 fc_params.filter_offset = params.weights_offset;
68 fc_params.activation.min = params.quantized_activation_min;
69 fc_params.activation.max = params.quantized_activation_max;
71 cmsis_nn_per_tensor_quant_params quant_params;
72 quant_params.multiplier = params.output_multiplier;
73 quant_params.shift = params.output_shift;
75 cmsis_nn_dims input_dims;
76 input_dims.n = batches;
79 input_dims.c = accum_depth;
81 cmsis_nn_dims filter_dims;
82 filter_dims.n = accum_depth;
85 filter_dims.c = output_depth;
87 cmsis_nn_dims bias_dims;
91 bias_dims.c = output_depth;
93 cmsis_nn_dims output_dims;
94 output_dims.n = batches;
97 output_dims.c = output_depth;
99 int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
100 auto buffer = std::make_unique<int8_t[]>(buf_size);
101 assert(buffer !=
nullptr);
103 cmsis_nn_context ctx;
104 ctx.buf = buffer.get();
108 arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
109 filter_data, &bias_dims, bias_data, &output_dims, output_data);
110 assert(res == ARM_MATH_SUCCESS);
// int8_t specialization of FullyConnected: int8 input, filter and output
// data, and int32 bias data. Each buffer is paired with a
// tflite::RuntimeShape describing it. Quantization offsets, the output
// multiplier/shift and the activation clamp range are read from `params`.
void FullyConnected< int8_t >(const tflite::FullyConnectedParams &params,
                              const tflite::RuntimeShape &input_shape,
                              const int8_t *input_data,
                              const tflite::RuntimeShape &filter_shape,
                              const int8_t *filter_data,
                              const tflite::RuntimeShape &bias_shape,
                              const int32_t *bias_data,
                              const tflite::RuntimeShape &output_shape,
                              int8_t *output_data)
// Float overload of FullyConnected that takes legacy `Shape` descriptors.
// The 6th and 8th parameters (presumably the bias and output shapes; TODO
// confirm against the definition) are unnamed, so this overload cannot
// read them.
void FullyConnected(const FullyConnectedParams &params,
                    const Shape &input_shape,
                    const float *input_data,
                    const Shape &weights_shape,
                    const float *weights_data,
                    const Shape &,
                    const float *bias_data,
                    const Shape &,
                    float *output_data)