42 const int accum_depth = filter_shape.
dims(filter_dim_count - 1);
44 cmsis_nn_fc_params fc_params;
51 cmsis_nn_per_tensor_quant_params quant_params;
55 cmsis_nn_dims input_dims;
56 input_dims.n = batches;
59 input_dims.c = accum_depth;
61 cmsis_nn_dims filter_dims;
62 filter_dims.n = accum_depth;
65 filter_dims.c = output_depth;
67 cmsis_nn_dims bias_dims;
71 bias_dims.c = output_depth;
73 cmsis_nn_dims output_dims;
74 output_dims.n = batches;
77 output_dims.c = output_depth;
79 int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
80 auto buffer = std::make_unique<int8_t[]>(buf_size);
81 assert(buffer !=
nullptr);
84 ctx.buf = buffer.get();
88 arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
89 filter_data, &bias_dims, bias_data, &output_dims, output_data);
90 assert(res == ARM_CMSIS_NN_SUCCESS);
91 if (res != ARM_CMSIS_NN_SUCCESS)
101 int16_t *output_data)
108 const int accum_depth = filter_shape.
dims(filter_dim_count - 1);
110 cmsis_nn_fc_params fc_params;
117 cmsis_nn_per_tensor_quant_params quant_params;
121 cmsis_nn_dims input_dims;
122 input_dims.n = batches;
125 input_dims.c = accum_depth;
127 cmsis_nn_dims filter_dims;
128 filter_dims.n = accum_depth;
131 filter_dims.c = output_depth;
133 cmsis_nn_dims bias_dims;
137 bias_dims.c = output_depth;
139 cmsis_nn_dims output_dims;
140 output_dims.n = batches;
143 output_dims.c = output_depth;
145 int32_t buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
146 auto buffer = std::make_unique<int8_t[]>(buf_size);
147 assert(buffer !=
nullptr);
149 cmsis_nn_context ctx;
150 ctx.buf = buffer.get();
154 arm_fully_connected_s16(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
155 filter_data, &bias_dims, bias_data, &output_dims, output_data);
156 assert(res == ARM_CMSIS_NN_SUCCESS);
158 if (res != ARM_CMSIS_NN_SUCCESS)
OMStatus FullyConnected< int8_t >(const core::FullyConnectedParams ¶ms, const int8_t *input_data, const core::OMRuntimeShape &filter_shape, const int8_t *filter_data, const int32_t *bias_data, const core::OMRuntimeShape &output_shape, int8_t *output_data)
OMStatus FullyConnected(const core::FullyConnectedParams ¶ms, const int16_t *input_data, const core::OMRuntimeShape &filter_shape, const int8_t *filter_data, const int64_t *bias_data, const core::OMRuntimeShape &output_shape, int16_t *output_data)