// Generic (non-quantized) path: a fused bias is not supported here, so it is
// rejected up front before delegating to the rank-generic kernel below.
if (biasv)
{
  throw std::runtime_error("non-quantized FullyConnected with fused bias is unsupported");
}

fullyConnected2D<T>(inputv, weightsv, res);
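// fullyConnected2D<T> computes, for every output position,
//   out[..., row, col] += in[..., row, i] * w[..., i, col]
// over the contracted dimension i. A single index object is patched and
// restored around each read instead of building fresh indices per access.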
mir::Tensor<T> input(inputv);
mir::Tensor<T> weights(weightsv);
mir::Tensor<T> accessor(res);

const mir::Shape &in_shape = input.getShape();
int32_t in_rank = in_shape.rank();

const mir::Shape &w_shape = weights.getShape();
int32_t w_rank = w_shape.rank();

// The contracted dimensions must match: last axis of the input against the
// second-to-last axis of the weights.
assert(in_shape.dim(in_rank - 1) == w_shape.dim(w_rank - 2));

mir::ShapeRange out_range(res.getShape());

// Length of the contraction (inner) loop.
int32_t len = w_shape.dim(w_rank - 2);
for (auto &out_index : out_range)
{
  mir::Index t_index = out_index;
  T &output_element = accessor.at(out_index);
  int32_t col = t_index.at(w_rank - 1);
  int32_t row = t_index.at(w_rank - 2);
  for (int32_t i = 0; i < len; ++i)
  {
    // Read in[..., row, i], then restore the column coordinate.
    t_index.at(w_rank - 1) = i;
    T in = input.at(t_index);
    t_index.at(w_rank - 1) = col;
    // Read w[..., i, col], then restore the row coordinate.
    t_index.at(w_rank - 2) = i;
    T w = weights.at(t_index);
    t_index.at(w_rank - 2) = row;
    output_element += in * w;
  }
}
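// Quantized (uint8) specialization of the same entry point. It follows the
// usual affine-quantization scheme, real = scale * (quantized - zero_point):
// offset-corrected values are accumulated in int32 and the sum is rescaled
// back into the output's uint8 domain at the end.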
// The quantized path requires a fused bias.
if (!biasv)
{
  throw std::runtime_error{"Quantized FullyConnected cannot be executed without fused bias"};
}

const auto &input_type = inputv.getType();
const auto &weights_type = weightsv.getType();
const auto &bias_type = biasv->getType();
const auto &output_type = res.getType();
assert(input_type.isQuantized());
assert(weights_type.isQuantized());
assert(bias_type.isQuantized());
assert(output_type.isQuantized());
assert(input_type.getElementType() == mir::DataType::UINT8);
assert(weights_type.getElementType() == mir::DataType::UINT8);
assert(bias_type.getElementType() == mir::DataType::INT32);
// Zero points of the input and weights are negated here so the inner loop
// can simply add the offsets.
int32_t input_offset = -input_type.getQuantization().getZeroPoint();
int32_t weights_offset = -weights_type.getQuantization().getZeroPoint();
int32_t output_offset = output_type.getQuantization().getZeroPoint();

double input_scale = input_type.getQuantization().getScale();
double weights_scale = weights_type.getQuantization().getScale();
double output_scale = output_type.getQuantization().getScale();
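// The int32 accumulator is in units of (input_scale * weights_scale);
// dividing by output_scale converts it into output quantization units:
//   real_multiplier = input_scale * weights_scale / output_scale.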
double real_multiplier = input_scale * weights_scale / output_scale;
int32_t output_multiplier = 0;
int output_shift = 0;
// Decompose the real multiplier into a fixed-point multiplier and a shift
// (TFLite-style helper from QuantizationHelpers.h).
QuantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
const auto &in_shape = inputv.getShape();
const auto &weights_shape = weightsv.getShape();

// The quantized kernel only handles the 2-D case:
// (batches, accum_depth) x (accum_depth, output_depth).
const int32_t batches = in_shape.dim(0);
assert(in_shape.rank() == 2);
assert(weights_shape.rank() == 2);
assert(in_shape.dim(1) == weights_shape.dim(0));
const int32_t accum_depth = weights_shape.dim(0);
const int32_t output_depth = weights_shape.dim(1);
// The tensors are densely stored, so raw pointers to element 0 suffice.
uint8_t *input_data = reinterpret_cast<uint8_t *>(inputv.atOffset(0));
uint8_t *weights_data = reinterpret_cast<uint8_t *>(weightsv.atOffset(0));
int32_t *bias_data = reinterpret_cast<int32_t *>(biasv->atOffset(0));

uint8_t *output_data = reinterpret_cast<uint8_t *>(res.atOffset(0));
int32_t output_min = std::numeric_limits<uint8_t>::min();
int32_t output_max = std::numeric_limits<uint8_t>::max();
for (int32_t b = 0; b < batches; ++b)
{
  for (int32_t out_c = 0; out_c < output_depth; ++out_c)
  {
    int32_t acc = 0;
    for (int d = 0; d < accum_depth; ++d)
    {
      int32_t input_val = input_data[b * accum_depth + d];
      int32_t weights_val = weights_data[d * output_depth + out_c];
      acc += (weights_val + weights_offset) * (input_val + input_offset);
    }
    acc += bias_data[out_c];
    // Rescale from accumulator units to output units, add the output zero
    // point, and clamp to the representable uint8 range.
    acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
    acc += output_offset;
    acc = std::max(acc, output_min);
    acc = std::min(acc, output_max);
    output_data[out_c + output_depth * b] = static_cast<uint8_t>(acc);
  }
}
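// Entry point: dispatch on the element type of the result tensor to either
// the generic template or the uint8 specialization above.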
void FullyConnected(const mir::TensorVariant &input, const mir::TensorVariant &weights,
                    const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res,
                    const mir::TensorVariant *bias)
{
  dispatch<FullyConnectedImpl>(res.getElementType(), input, weights, op, res, bias);
}