ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FullyConnected.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "kernels/FullyConnected.h"
#include "kernels/Utils.h"

#include "PALFullyConnected.h"

#include <stdexcept>
24
25namespace luci_interpreter
26{
27
28namespace kernels
29{
30
31FullyConnected::FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias,
32 Tensor *output, const FullyConnectedParams &params)
34{
35}
36
38{
39 if (weights()->element_type() == DataType::U8)
40 {
41 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::U8);
42 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::U8);
43 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
44 }
45 else if (weights()->element_type() == DataType::FLOAT32)
46 {
47 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
48 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
49 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
50 }
51 else if (weights()->element_type() == DataType::S8)
52 {
53 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::S8);
54 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8);
55 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
56 }
57 else if (weights()->element_type() == DataType::S4)
58 {
59 // TODO support other combinations when needed
60 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
61 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
62 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
63 }
64 else if (weights()->element_type() == DataType::U4)
65 {
66 // TODO support other combinations when needed
67 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
68 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
69 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
70 }
71 else
72 {
73 throw std::runtime_error("luci-intp FullyConnected(1) Unsupported type.");
74 }
75
76 const Shape &input_shape = input()->shape();
77 const Shape &weights_shape = weights()->shape();
78
79 LUCI_INTERPRETER_CHECK(weights_shape.num_dims() == 2);
80 LUCI_INTERPRETER_CHECK(bias() == nullptr ||
81 bias()->shape().num_elements() == weights_shape.dim(0));
82
83 LUCI_INTERPRETER_CHECK(input_shape.num_elements() % weights_shape.dim(1) == 0);
84 const int32_t batch_size = input_shape.num_elements() / weights_shape.dim(1);
85 const int32_t num_units = weights_shape.dim(0);
86
87 if (params().keep_num_dims == false)
88 {
89 output()->resize({batch_size, num_units});
90 }
91 else
92 {
94 for (int i = 0; i < input_shape.num_dims(); ++i)
95 output_shape.dim(i) = input_shape.dim(i);
96 output_shape.dim(input_shape.num_dims() - 1) = num_units;
98 }
99}
100
102{
103 const bool is_hybrid =
104 (input()->element_type() == DataType::FLOAT32 &&
105 (weights()->element_type() == DataType::S4 || weights()->element_type() == DataType::U4) &&
106 output()->element_type() == DataType::FLOAT32 &&
107 (!bias() || bias()->element_type() == DataType::FLOAT32));
108 if (is_hybrid)
109 {
110 switch (weights()->element_type())
111 {
112 case DataType::S4:
113 evalHybridWI4AF32();
114 break;
115 case DataType::U4:
116 evalHybridWU4AF32();
117 break;
118 default:
119 throw std::runtime_error("luci-intp FullyConnected(3) Unsupported type.");
120 }
121 }
122 else
123 {
124 switch (input()->element_type())
125 {
126 case DataType::U8:
127 evalQuantized();
128 break;
129 case DataType::S8:
130 evalQuantizedS8();
131 break;
132 case DataType::FLOAT32:
133 evalFloat();
134 break;
135 default:
136 throw std::runtime_error("luci-intp FullyConnected(2) Unsupported type.");
137 }
138 }
139}
140
141void FullyConnected::evalFloat() const
142{
143 float activation_min{};
144 float activation_max{};
145 calculateActivationRange(_params.activation, &activation_min, &activation_max);
146
147 tflite::FullyConnectedParams params{};
148 params.float_activation_min = activation_min;
149 params.float_activation_max = activation_max;
150 params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault;
151
152 tflite::reference_ops::FullyConnected(
153 params, getTensorShape(input()), getTensorData<float>(input()), getTensorShape(weights()),
154 getTensorData<float>(weights()), getTensorShape(bias()), getTensorData<float>(bias()),
155 getTensorShape(output()), getTensorData<float>(output()));
156}
157
158void FullyConnected::evalQuantized() const
159{
160 double real_multiplier = 0.0;
161 int output_shift;
162 int32_t output_activation_min;
163 int32_t output_activation_max;
164 int32_t output_multiplier;
165 real_multiplier =
167 quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
168 calculateActivationRangeQuantized(params().activation, output(), &output_activation_min,
169 &output_activation_max);
170
171 int32_t input_offset = -input()->zero_point();
172 int32_t filter_offset = -weights()->zero_point();
173 int32_t output_offset = output()->zero_point();
174
175 tflite::FullyConnectedParams op_params{};
176 op_params.input_offset = input_offset;
177 op_params.weights_offset = filter_offset;
178 op_params.output_offset = output_offset;
179 op_params.output_multiplier = output_multiplier;
180 op_params.output_shift = output_shift;
181 op_params.quantized_activation_min = output_activation_min;
182 op_params.quantized_activation_max = output_activation_max;
183 op_params.lhs_cacheable = false;
184 op_params.rhs_cacheable = false;
185 tflite::reference_ops::FullyConnected(
186 op_params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(weights()),
187 getTensorData<uint8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
188 getTensorShape(output()), getTensorData<uint8_t>(output()));
189}
190
191void FullyConnected::evalQuantizedS8() const
192{
193 double real_multiplier = 0.0;
194 int output_shift;
195 int32_t output_activation_min;
196 int32_t output_activation_max;
197 int32_t output_multiplier;
198 real_multiplier =
200 quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
201 calculateActivationRangeQuantized(params().activation, output(), &output_activation_min,
202 &output_activation_max);
203
204 int32_t input_offset = -input()->zero_point();
205 int32_t filter_offset = -weights()->zero_point();
206 int32_t output_offset = output()->zero_point();
207
208 tflite::FullyConnectedParams op_params{};
209 op_params.input_offset = input_offset;
210 op_params.weights_offset = filter_offset;
211 op_params.output_offset = output_offset;
212 op_params.output_multiplier = output_multiplier;
213 op_params.output_shift = output_shift;
214 op_params.quantized_activation_min = output_activation_min;
215 op_params.quantized_activation_max = output_activation_max;
216 op_params.lhs_cacheable = false;
217 op_params.rhs_cacheable = false;
219 op_params, getTensorShape(input()), getTensorData<int8_t>(input()), getTensorShape(weights()),
220 getTensorData<int8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
221 getTensorShape(output()), getTensorData<int8_t>(output()));
222}
223
// Hybrid evaluation: S4-quantized weights are dequantized to float into the
// scratch tensor (per-tensor or per-channel), then the regular TFLite float
// kernel is run against the dequantized copy.
// NOTE(review): weights are read one value per int8_t element — assumes the
// int4 data is not packed two-per-byte; confirm against the tensor layout.
void FullyConnected::evalHybridWI4AF32() const
{
  float activation_min{};
  float activation_max{};
  calculateActivationRange(_params.activation, &activation_min, &activation_max);

  tflite::FullyConnectedParams params{};
  params.float_activation_min = activation_min;
  params.float_activation_max = activation_max;
  params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault;

  const int8_t *weights_int4 = getTensorData<int8_t>(weights());
  float *weights_float = getTensorData<float>(scratch());
  const Shape &weights_shape = weights()->shape();
  const auto weights_scales = weights()->scales();
  const auto weights_quantized_dimension = weights()->quantized_dimension();
  // Invariant for per-channel quantization of FC weights.
  LUCI_INTERPRETER_CHECK(weights_quantized_dimension == 0);

  if (weights_scales.size() == 1)
  {
    // Per tensor: a single scale applies to every weight element.
    const auto scale = weights()->scale();
    for (int32_t i = 0; i < weights_shape.num_elements(); ++i)
    {
      weights_float[i] = scale * static_cast<float>(weights_int4[i]);
    }
  }
  else
  {
    // Per channel: one scale per slice along the quantized dimension.
    const int32_t quant_dim_size = weights_shape.dim(weights_quantized_dimension);

    // Element counts before/after the quantized dimension; combined with
    // quant_dim_size they reconstruct each element's flat offset.
    size_t outer_dims_size = 1;
    size_t inner_dims_size = 1;
    for (int i = 0; i < weights_quantized_dimension; ++i)
      outer_dims_size *= weights_shape.dim(i);
    for (int i = weights_quantized_dimension + 1; i < weights_shape.num_dims(); ++i)
      inner_dims_size *= weights_shape.dim(i);

    for (size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it)
      for (int32_t channel = 0; channel < quant_dim_size; ++channel)
      {
        float scale = weights_scales[channel];
        size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel);
        for (size_t inner_it = 0; inner_it < inner_dims_size; ++inner_it)
        {
          // Guard against writing past the scratch buffer.
          LUCI_INTERPRETER_CHECK(offset + inner_it <
                                 static_cast<size_t>(weights_shape.num_elements()));
          weights_float[offset + inner_it] =
            scale * static_cast<float>(weights_int4[offset + inner_it]);
        }
      }
  }

  // Run the float kernel with scratch() standing in for the weights.
  tflite::reference_ops::FullyConnected(
    params, getTensorShape(input()), getTensorData<float>(input()), getTensorShape(scratch()),
    getTensorData<float>(scratch()), getTensorShape(bias()), getTensorData<float>(bias()),
    getTensorShape(output()), getTensorData<float>(output()));
}
284
// Hybrid evaluation: U4-quantized weights (value range 0..15 with a zero
// point) are dequantized to float into the scratch tensor, per-tensor or
// per-channel, then the regular TFLite float kernel is run on the copy.
// NOTE(review): weights are read one value per uint8_t element — assumes the
// uint4 data is not packed two-per-byte; confirm against the tensor layout.
void FullyConnected::evalHybridWU4AF32() const
{
  float activation_min{};
  float activation_max{};
  calculateActivationRange(_params.activation, &activation_min, &activation_max);

  tflite::FullyConnectedParams params{};
  params.float_activation_min = activation_min;
  params.float_activation_max = activation_max;
  params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault;

  const auto *weights_uint4 = getTensorData<uint8_t>(weights());
  auto *weights_float = getTensorData<float>(scratch());
  const Shape &weights_shape = weights()->shape();
  const auto weights_scales = weights()->scales();
  const auto weights_zero_points = weights()->zero_points();
  const auto weights_quantized_dimension = weights()->quantized_dimension();
  // Invariant for per-channel quantization of FC weights.
  LUCI_INTERPRETER_CHECK(weights_quantized_dimension == 0);
  if (weights_scales.size() == 1)
  {
    // Per tensor: one scale and one zero point for every element.
    const auto scale = weights()->scale();
    const auto zero_point = weights()->zero_point();
    // A uint4 zero point must itself fit in the 0..15 value range.
    LUCI_INTERPRETER_CHECK(zero_point >= 0 and zero_point <= 15);
    for (int32_t i = 0; i < weights_shape.num_elements(); ++i)
    {
      weights_float[i] =
        scale * static_cast<float>(static_cast<int32_t>(weights_uint4[i]) - zero_point);
    }
  }
  else
  {
    // Per channel: one scale/zero-point pair per slice along the quantized
    // dimension.
    const int32_t quant_dim_size = weights_shape.dim(weights_quantized_dimension);

    // Element counts before/after the quantized dimension; combined with
    // quant_dim_size they reconstruct each element's flat offset.
    size_t outer_dims_size = 1;
    size_t inner_dims_size = 1;
    for (int i = 0; i < weights_quantized_dimension; ++i)
      outer_dims_size *= weights_shape.dim(i);
    for (int i = weights_quantized_dimension + 1; i < weights_shape.num_dims(); ++i)
      inner_dims_size *= weights_shape.dim(i);

    for (size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it)
      for (int32_t channel = 0; channel < quant_dim_size; ++channel)
      {
        int32_t zero_point = weights_zero_points[channel];
        // A uint4 zero point must itself fit in the 0..15 value range.
        LUCI_INTERPRETER_CHECK(zero_point >= 0 and zero_point <= 15);
        float scale = weights_scales[channel];
        size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel);
        for (size_t inner_it = 0; inner_it < inner_dims_size; ++inner_it)
        {
          weights_float[offset + inner_it] =
            scale *
            static_cast<float>(static_cast<int32_t>(weights_uint4[offset + inner_it]) - zero_point);
        }
      }
  }

  // Run the float kernel with scratch() standing in for the weights.
  tflite::reference_ops::FullyConnected(
    params, getTensorShape(input()), getTensorData<float>(input()), getTensorShape(scratch()),
    getTensorData<float>(scratch()), getTensorShape(bias()), getTensorData<float>(bias()),
    getTensorShape(output()), getTensorData<float>(output()));
}
348
349} // namespace kernels
350} // namespace luci_interpreter
const FullyConnectedParams & params() const
Definition Kernel.h:67
int32_t dim(int i) const
Definition Tensor.h:41
int32_t num_elements() const
Definition Tensor.h:53
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const std::vector< float > & scales() const
Definition Tensor.h:121
float scale() const
Definition Tensor.h:109
const std::vector< int32_t > & zero_points() const
Definition Tensor.h:123
DataType element_type() const
Definition Tensor.h:105
int32_t quantized_dimension() const
Definition Tensor.h:125
int32_t zero_point() const
Definition Tensor.h:115
FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output, const FullyConnectedParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
__global uchar * offset(const Image *img, int x, int y)
Definition helpers.h:540
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void calculateActivationRange(Activation activation, T *activation_min, T *activation_max)
Definition Utils.cpp:52
void calculateActivationRangeQuantized(Activation activation, const Tensor *output, int32_t *activation_min, int32_t *activation_max)
Definition Utils.cpp:119
double getQuantizedConvolutionMultipler(float input_scale, float filter_scale, float output_scale)
Definition Utils.h:137
void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
Definition Utils.cpp:157
void FullyConnected< int8_t >(const tflite::FullyConnectedParams &params, const tflite::RuntimeShape &input_shape, const int8_t *input_data, const tflite::RuntimeShape &filter_shape, const int8_t *filter_data, const tflite::RuntimeShape &bias_shape, const int32_t *bias_data, const tflite::RuntimeShape &output_shape, int8_t *output_data)
Definition Shape.h:28