29 const int32_t *,
const int8_t *input_data,
30 const int32_t *filter_shape,
const int8_t *filter_data,
32 int8_t *output_data, uint32_t output_dims_count,
33 uint32_t weights_dims_count)
36 const int output_depth =
output_shape[output_dims_count - 1];
37 const int accum_depth = filter_shape[weights_dims_count - 1];
39 cmsis_nn_fc_params fc_params;
46 cmsis_nn_per_tensor_quant_params quant_params;
50 cmsis_nn_dims input_dims;
51 input_dims.n = batches;
54 input_dims.c = accum_depth;
56 cmsis_nn_dims filter_dims;
57 filter_dims.n = accum_depth;
60 filter_dims.c = output_depth;
62 cmsis_nn_dims bias_dims;
66 bias_dims.c = output_depth;
68 cmsis_nn_dims output_dims;
69 output_dims.n = batches;
72 output_dims.c = output_depth;
74 int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
75 auto buffer = std::make_unique<int8_t[]>(buf_size);
76 assert(buffer !=
nullptr);
79 ctx.buf = buffer.get();
83 arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
84 filter_data, &bias_dims, bias_data, &output_dims, output_data);
85 assert(res == ARM_CMSIS_NN_SUCCESS);
90 const int32_t *,
const int16_t *input_data,
const int32_t *filter_shape,
91 const int8_t *filter_data,
const int64_t *bias_data,
93 uint32_t output_dims_count, uint32_t weights_dims_count)
96 const int output_depth =
output_shape[output_dims_count - 1];
97 const int accum_depth = filter_shape[weights_dims_count - 1];
99 cmsis_nn_fc_params fc_params;
106 cmsis_nn_per_tensor_quant_params quant_params;
110 cmsis_nn_dims input_dims;
111 input_dims.n = batches;
114 input_dims.c = accum_depth;
116 cmsis_nn_dims filter_dims;
117 filter_dims.n = accum_depth;
120 filter_dims.c = output_depth;
122 cmsis_nn_dims bias_dims;
126 bias_dims.c = output_depth;
128 cmsis_nn_dims output_dims;
129 output_dims.n = batches;
132 output_dims.c = output_depth;
134 int32_t buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
135 auto buffer = std::make_unique<int8_t[]>(buf_size);
136 assert(buffer !=
nullptr);
138 cmsis_nn_context ctx;
139 ctx.buf = buffer.get();
143 arm_fully_connected_s16(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
144 filter_data, &bias_dims, bias_data, &output_dims, output_data);
145 assert(res == ARM_CMSIS_NN_SUCCESS);
void FullyConnected< int8_t >(const tflite::FullyConnectedParams &params, const tflite::RuntimeShape &input_shape, const int8_t *input_data, const tflite::RuntimeShape &filter_shape, const int8_t *filter_data, const tflite::RuntimeShape &bias_shape, const int32_t *bias_data, const tflite::RuntimeShape &output_shape, int8_t *output_data)