ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALFullyConnected.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef ONERT_MICRO_EXECUTE_PAL_FULLY_CONNECTED_H
19#define ONERT_MICRO_EXECUTE_PAL_FULLY_CONNECTED_H
20
21#include "PALFullyConnectedCommon.h"
22
23#include <arm_nnfunctions.h>
24
25namespace onert_micro
26{
27namespace execute
28{
29namespace pal
30{
31template <>
32OMStatus FullyConnected<int8_t>(const core::FullyConnectedParams &params, const int8_t *input_data,
33 const core::OMRuntimeShape &filter_shape, const int8_t *filter_data,
34 const int32_t *bias_data, const core::OMRuntimeShape &output_shape,
35 int8_t *output_data)
36{
37 const int filter_dim_count = filter_shape.dimensionsCount();
38 const int output_dim_count = output_shape.dimensionsCount();
39 const int batches =
40 flatSizeSkipDim(output_shape.dimsData(), output_dim_count - 1, output_dim_count);
41 const int output_depth = output_shape.dims(output_dim_count - 1);
42 const int accum_depth = filter_shape.dims(filter_dim_count - 1);
43
44 cmsis_nn_fc_params fc_params;
45 fc_params.input_offset = params.input_offset;
46 fc_params.output_offset = params.output_offset;
47 fc_params.filter_offset = params.weights_offset;
48 fc_params.activation.min = params.quantized_activation_min;
49 fc_params.activation.max = params.quantized_activation_max;
50
51 cmsis_nn_per_tensor_quant_params quant_params;
52 quant_params.multiplier = params.output_multiplier;
53 quant_params.shift = params.output_shift;
54
55 cmsis_nn_dims input_dims;
56 input_dims.n = batches;
57 input_dims.h = 1;
58 input_dims.w = 1;
59 input_dims.c = accum_depth;
60
61 cmsis_nn_dims filter_dims;
62 filter_dims.n = accum_depth;
63 filter_dims.h = 1;
64 filter_dims.w = 1;
65 filter_dims.c = output_depth;
66
67 cmsis_nn_dims bias_dims;
68 bias_dims.n = 1;
69 bias_dims.h = 1;
70 bias_dims.w = 1;
71 bias_dims.c = output_depth;
72
73 cmsis_nn_dims output_dims;
74 output_dims.n = batches;
75 output_dims.h = 1;
76 output_dims.w = 1;
77 output_dims.c = output_depth;
78
79 int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
80 auto buffer = std::make_unique<int8_t[]>(buf_size);
81 assert(buffer != nullptr);
82
83 cmsis_nn_context ctx;
84 ctx.buf = buffer.get();
85 ctx.size = buf_size;
86
87 auto res =
88 arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
89 filter_data, &bias_dims, bias_data, &output_dims, output_data);
90 assert(res == ARM_CMSIS_NN_SUCCESS);
91 if (res != ARM_CMSIS_NN_SUCCESS)
92 return CmsisNNError;
93
94 return Ok;
95}
96
97template <>
98OMStatus FullyConnected(const core::FullyConnectedParams &params, const int16_t *input_data,
99 const core::OMRuntimeShape &filter_shape, const int8_t *filter_data,
100 const int64_t *bias_data, const core::OMRuntimeShape &output_shape,
101 int16_t *output_data)
102{
103 const int filter_dim_count = filter_shape.dimensionsCount();
104 const int output_dim_count = output_shape.dimensionsCount();
105 const int batches =
106 flatSizeSkipDim(output_shape.dimsData(), output_dim_count - 1, output_dim_count);
107 const int output_depth = output_shape.dims(output_dim_count - 1);
108 const int accum_depth = filter_shape.dims(filter_dim_count - 1);
109
110 cmsis_nn_fc_params fc_params;
111 fc_params.input_offset = params.input_offset;
112 fc_params.output_offset = params.output_offset;
113 fc_params.filter_offset = params.weights_offset;
114 fc_params.activation.min = params.quantized_activation_min;
115 fc_params.activation.max = params.quantized_activation_max;
116
117 cmsis_nn_per_tensor_quant_params quant_params;
118 quant_params.multiplier = params.output_multiplier;
119 quant_params.shift = params.output_shift;
120
121 cmsis_nn_dims input_dims;
122 input_dims.n = batches;
123 input_dims.h = 1;
124 input_dims.w = 1;
125 input_dims.c = accum_depth;
126
127 cmsis_nn_dims filter_dims;
128 filter_dims.n = accum_depth;
129 filter_dims.h = 1;
130 filter_dims.w = 1;
131 filter_dims.c = output_depth;
132
133 cmsis_nn_dims bias_dims;
134 bias_dims.n = 1;
135 bias_dims.h = 1;
136 bias_dims.w = 1;
137 bias_dims.c = output_depth;
138
139 cmsis_nn_dims output_dims;
140 output_dims.n = batches;
141 output_dims.h = 1;
142 output_dims.w = 1;
143 output_dims.c = output_depth;
144
145 int32_t buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
146 auto buffer = std::make_unique<int8_t[]>(buf_size);
147 assert(buffer != nullptr);
148
149 cmsis_nn_context ctx;
150 ctx.buf = buffer.get();
151 ctx.size = buf_size;
152
153 auto res =
154 arm_fully_connected_s16(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
155 filter_data, &bias_dims, bias_data, &output_dims, output_data);
156 assert(res == ARM_CMSIS_NN_SUCCESS);
157
158 if (res != ARM_CMSIS_NN_SUCCESS)
159 return CmsisNNError;
160
161 return Ok;
162}
163
164} // namespace pal
165} // namespace execute
166} // namespace onert_micro
167
#endif // ONERT_MICRO_EXECUTE_PAL_FULLY_CONNECTED_H
int32_t dimensionsCount() const
Definition Tensor.h:106
int32_t dims(int i) const
Definition Tensor.h:108
const luci_interpreter::RuntimeShape output_shape
int flatSizeSkipDim(const int32_t *dims_data, int skip_dim, int num_dims)
Definition PALUtils.h:210
OMStatus FullyConnected< int8_t >(const core::FullyConnectedParams &params, const int8_t *input_data, const core::OMRuntimeShape &filter_shape, const int8_t *filter_data, const int32_t *bias_data, const core::OMRuntimeShape &output_shape, int8_t *output_data)
OMStatus FullyConnected(const core::FullyConnectedParams &params, const int16_t *input_data, const core::OMRuntimeShape &filter_shape, const int8_t *filter_data, const int64_t *bias_data, const core::OMRuntimeShape &output_shape, int16_t *output_data)