ONE - On-device Neural Engine
FullyConnected.cpp
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "FullyConnected.h"
#include "Common.h"

#include "QuantizationHelpers.h"

#include "mir/Tensor.h"

// Standard library facilities used below.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <stdexcept>

namespace mir_interpreter
{

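// Fast path for the rank-2 case: a plain row-major matrix multiplication
// output[r][c] += input[r][k] * weights[k][c]. The k-loop sits in the middle so each input
// element is loaded once and the inner c-loop walks a contiguous row of weights. The caller
// (run() below) clears the output with erase<T>() first, since the loop accumulates with +=.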
template <typename T>
static void fullyConnected2D(const mir::TensorVariant &input, const mir::TensorVariant &weights,
                             mir::TensorVariant &output)
{
  assert(input.getShape().rank() == 2);
  assert(weights.getShape().rank() == 2);
  assert(input.getShape().dim(1) == weights.getShape().dim(0));

  auto in_raw = reinterpret_cast<T *>(input.atOffset(0));
  auto weight_raw = reinterpret_cast<T *>(weights.atOffset(0));
  auto output_raw = reinterpret_cast<T *>(output.atOffset(0));

  auto rows = output.getShape().dim(0);
  auto cols = output.getShape().dim(1);
  auto N = input.getShape().dim(1);
  auto wcols = weights.getShape().dim(1);

  for (int32_t r = 0; r < rows; ++r)
  {
    for (int32_t k = 0; k < N; ++k)
    {
      auto in = in_raw[r * N + k];

      for (int32_t c = 0; c < cols; ++c)
      {
        output_raw[r * cols + c] += in * weight_raw[k * wcols + c];
      }
    }
  }
}

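// Per-element-type implementation selected by dispatch<FullyConnectedImpl>() in
// FullyConnected() at the bottom of this file: the primary template handles non-quantized
// types, and the uint8_t specialization further down handles the quantized path.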
template <typename T> struct FullyConnectedImpl
{
  static void run(const mir::TensorVariant &inputv, const mir::TensorVariant &weightsv,
                  const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res,
                  const mir::TensorVariant *biasv);
};

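// Non-quantized path: rejects a fused bias, zeroes the result, takes the 2-D fast path when
// all three tensors are rank 2, and otherwise accumulates by iterating over the output index
// space, reusing a single mir::Index to address input (..., row, k) and weights (..., k, col).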
template <typename T>
void FullyConnectedImpl<T>::run(const mir::TensorVariant &inputv,
                                const mir::TensorVariant &weightsv,
                                const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res,
                                const mir::TensorVariant *biasv)
{
  if (biasv)
  {
    throw std::runtime_error("non-quantized FullyConnected with fused bias is unsupported");
  }

  mir::Tensor<T> input{inputv};
  mir::Tensor<T> weights{weightsv};

  erase<T>(res);

  if (input.getShape().rank() == 2 && weights.getShape().rank() == 2 && res.getShape().rank() == 2)
  {
    // optimized case for 2d matrix multiplication
    fullyConnected2D<T>(inputv, weightsv, res);
    return;
  }

  mir::Tensor<T> accessor(res);

  const mir::Shape &in_shape = input.getShape();
  int32_t in_rank = in_shape.rank();

  const mir::Shape &w_shape = weights.getShape();
  int32_t w_rank = w_shape.rank();

  assert(in_shape.dim(in_rank - 1) == w_shape.dim(w_rank - 2));
  (void)in_rank;

  mir::ShapeRange out_range(res.getShape());

  // Contraction length: the dimension shared by input and weights.
  int32_t len = w_shape.dim(w_rank - 2);

  for (auto &out_index : out_range)
  {
    mir::Index t_index = out_index;
    T &output_element = accessor.at(out_index);
    int32_t col = t_index.at(w_rank - 1);
    int32_t row = t_index.at(w_rank - 2);
    for (int32_t i = 0; i < len; ++i)
    {
      // Temporarily rewrite the last two index components to read input[..., row, i]
      // and weights[..., i, col], then restore them for the next iteration.
      t_index.at(w_rank - 1) = i;
      T in = input.at(t_index);
      t_index.at(w_rank - 1) = col;
      t_index.at(w_rank - 2) = i;
      T w = weights.at(t_index);
      t_index.at(w_rank - 2) = row;
      output_element += in * w;
    }
  }
}

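// Quantized (uint8) path, in the style of the TensorFlow Lite reference kernel (see the
// TensorFlow copyright above): accumulate (input + input_offset) * (weights + weights_offset)
// in int32, add the int32 bias, rescale with a fixed-point multiplier derived from
// input_scale * weights_scale / output_scale, add the output zero point, and clamp to uint8.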
template <> struct FullyConnectedImpl<uint8_t>
{
  static void run(const mir::TensorVariant &inputv, const mir::TensorVariant &weightsv,
                  const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res,
                  const mir::TensorVariant *biasv);
};

void FullyConnectedImpl<uint8_t>::run(const mir::TensorVariant &inputv,
                                      const mir::TensorVariant &weightsv,
                                      const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res,
                                      const mir::TensorVariant *biasv)
{
  if (!biasv)
  {
    throw std::runtime_error{"Quantized FullyConnected cannot be executed without fused bias"};
  }

  const auto &input_type = inputv.getType();
  const auto &weights_type = weightsv.getType();
  const auto &bias_type = biasv->getType();
  const auto &output_type = op.getOutput(0)->getType();
  (void)bias_type;

  assert(input_type.isQuantized());
  assert(weights_type.isQuantized());
  assert(bias_type.isQuantized());
  assert(output_type.isQuantized());
  assert(input_type.getElementType() == mir::DataType::UINT8);
  assert(weights_type.getElementType() == mir::DataType::UINT8);
  assert(bias_type.getElementType() == mir::DataType::INT32);

  int32_t input_offset = -input_type.getQuantization().getZeroPoint();
  int32_t weights_offset = -weights_type.getQuantization().getZeroPoint();
  int32_t output_offset = output_type.getQuantization().getZeroPoint();

  double input_scale = input_type.getQuantization().getScale();
  double weights_scale = weights_type.getQuantization().getScale();
  double output_scale = output_type.getQuantization().getScale();

  double real_multiplier = input_scale * weights_scale / output_scale;
  int32_t output_multiplier = 0;
  int output_shift = 0;
  QuantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
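  // QuantizeMultiplier() splits real_multiplier into a 32-bit fixed-point multiplier and a
  // shift, so the float rescale can be replaced by an integer multiply-and-shift below.
  // For example (hypothetical values, not taken from any model): with input_scale = 0.5,
  // weights_scale = 0.25 and output_scale = 1.0, real_multiplier = 0.125, and an accumulator
  // of 40 is rescaled to about 40 * 0.125 = 5 before the output zero point is added and the
  // result is clamped to [0, 255].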

  const mir::Shape &in_shape = inputv.getShape();
  const mir::Shape &weights_shape = weightsv.getShape();
  const mir::Shape &out_shape = op.getOutputShape(0);

  const int32_t batches = in_shape.dim(0);
  assert(in_shape.rank() == 2);
  assert(weights_shape.rank() == 2);
  assert(in_shape.dim(1) == weights_shape.dim(0));
  const int32_t accum_depth = weights_shape.dim(0);
  const int32_t output_depth = weights_shape.dim(1);

  uint8_t *input_data = reinterpret_cast<uint8_t *>(inputv.atOffset(0));
  uint8_t *weights_data = reinterpret_cast<uint8_t *>(weightsv.atOffset(0));
  int32_t *bias_data = reinterpret_cast<int32_t *>(biasv->atOffset(0));

  uint8_t *output_data = reinterpret_cast<uint8_t *>(res.atOffset(0));

  int32_t output_min = std::numeric_limits<uint8_t>::min();
  int32_t output_max = std::numeric_limits<uint8_t>::max();

  for (int32_t b = 0; b < batches; ++b)
  {
    for (int32_t out_c = 0; out_c < output_depth; ++out_c)
    {
      int32_t acc = 0;
      for (int d = 0; d < accum_depth; ++d)
      {
        int32_t input_val = input_data[b * accum_depth + d];
        int32_t weights_val = weights_data[d * output_depth + out_c];
        acc += (weights_val + weights_offset) * (input_val + input_offset);
      }
      acc += bias_data[out_c];
      acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
      acc += output_offset;
      acc = std::max(acc, output_min);
      acc = std::min(acc, output_max);
      output_data[out_c + output_depth * b] = static_cast<uint8_t>(acc);
    }
  }
}

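// Public entry point: picks the FullyConnectedImpl specialization matching the element type of
// the result tensor (uint8_t for the quantized kernel, the primary template otherwise).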
void FullyConnected(const mir::TensorVariant &input, const mir::TensorVariant &weights,
                    const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res,
                    const mir::TensorVariant *bias)
{
  dispatch<FullyConnectedImpl>(res.getElementType(), input, weights, op, res, bias);
}
} // namespace mir_interpreter
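The 2-D fast path above is an ordinary matrix multiplication. The self-contained sketch below reproduces the same r/k/c loop on plain std::vector buffers, independent of the mir tensor types; the shapes and values (a 1x3 input and 3x2 weights) are made up for illustration and are not part of the interpreter.

// Standalone illustration of the fullyConnected2D loop on plain row-major buffers.
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
  const int32_t rows = 1, N = 3, cols = 2;
  std::vector<float> input = {1.f, 2.f, 3.f};                  // 1x3, row-major
  std::vector<float> weights = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; // 3x2, row-major
  std::vector<float> output(rows * cols, 0.f);                 // zero-initialized, as erase<T>() does

  for (int32_t r = 0; r < rows; ++r)
    for (int32_t k = 0; k < N; ++k)
    {
      float in = input[r * N + k]; // loaded once, reused across the output columns
      for (int32_t c = 0; c < cols; ++c)
        output[r * cols + c] += in * weights[k * cols + c];
    }

  std::printf("%g %g\n", output[0], output[1]); // prints: 22 28
}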