ONE - On-device Neural Engine
Quantize.cpp
/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/Quantize.h"
#include "kernels/Utils.h"
#include "PALQuantize.h"

namespace luci_interpreter
{
namespace kernels
{

namespace
{

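// Rescales data that is already quantized (S8 / U8 / S16) to the output tensor's scale and
// zero point. The effective rescale factor (input_scale / output_scale) is folded into a
// 32-bit fixed-point multiplier and a power-of-two shift by quantizeMultiplier(), and the
// per-element work is delegated to the PAL Requantize routine.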
template <typename input_dtype> void call_requantize(const Tensor *input, Tensor *output)
{
  int32_t multiplier;
  int shift;

  const double effective_output_scale = input->scale() / output->scale();
  quantizeMultiplier(effective_output_scale, &multiplier, &shift);

  const auto input_shape = getTensorShape(input);
  const auto output_shape = getTensorShape(output);
  const auto size = tflite::MatchingFlatSize(input_shape, output_shape);

  const auto input_data = getTensorData<input_dtype>(input);

  switch (output->element_type())
  {
    case loco::DataType::S8:
      luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
                                       output->zero_point(), getTensorData<int8_t>(output));
      break;
    case loco::DataType::U8:
      luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
                                       output->zero_point(), getTensorData<uint8_t>(output));
      break;
    case loco::DataType::S16:
      luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
                                       output->zero_point(), getTensorData<int16_t>(output));
      break;
    default:
      throw std::runtime_error("Unsupported quantized type, yet!");
  }
}

} // namespace

Quantize::Quantize(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}

// configure() checks that the requested input/output type combination is supported and
// resizes the output tensor to the input shape.
void Quantize::configure()
{

  if (input()->element_type() == loco::DataType::S16)
    LUCI_INTERPRETER_CHECK(input()->zero_point() == 0);

  switch (input()->element_type())
  {
    case loco::DataType::FLOAT32:
    {
      LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::U8 ||
                             output()->element_type() == loco::DataType::S8 ||
                             output()->element_type() == loco::DataType::S16);
      break;
    }
    case loco::DataType::S16:
    case loco::DataType::S8:
    case loco::DataType::U8:
    {
      LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::S8 ||
                             output()->element_type() == loco::DataType::U8 ||
                             output()->element_type() == loco::DataType::S16);
      if (output()->element_type() == loco::DataType::S16)
      {
        LUCI_INTERPRETER_CHECK(output()->zero_point() == 0);
      }
      break;
    }
    default:
      throw std::runtime_error("Unsupported type");
  }

  output()->resize(input()->shape());
}

// execute() dispatches on the input element type: FLOAT32 data is quantized with the PAL
// Quantize routine, while already-quantized data is rescaled through call_requantize().
void Quantize::execute() const
{
  switch (input()->element_type())
  {
    case loco::DataType::FLOAT32:
    {
      tflite::QuantizationParams op_params;
      op_params.zero_point = output()->zero_point();
      op_params.scale = output()->scale();
      const auto input_data = getTensorData<float>(input());

      switch (output()->element_type())
      {
        case loco::DataType::S8:
        {
          luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
                                         getTensorShape(output()), getTensorData<int8_t>(output()));
          break;
        }
        case loco::DataType::U8:
        {
          luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
                                         getTensorShape(output()),
                                         getTensorData<uint8_t>(output()));
          break;
        }
        case loco::DataType::S16:
        {
          luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
                                         getTensorShape(output()),
                                         getTensorData<int16_t>(output()));
          break;
        }
        default:
          throw std::runtime_error("luci-intp Quantize(1) Unsupported type.");
      }
      break;
    }
    case loco::DataType::S16:
    {
      call_requantize<int16_t>(input(), output());
      break;
    }
    case loco::DataType::S8:
    {
      call_requantize<int8_t>(input(), output());
      break;
    }
    case loco::DataType::U8:
    {
      call_requantize<uint8_t>(input(), output());
      break;
    }
    default:
      throw std::runtime_error("luci-intp Quantize(2) Unsupported type.");
  }
}

} // namespace kernels
} // namespace luci_interpreter
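As a quick reference for the arithmetic these PAL calls stand for: quantization maps a real value x to q = round(x / scale) + zero_point, and requantization rescales an already-quantized value by the effective scale input_scale / output_scale. The sketch below is a minimal scalar illustration of both paths, assuming int8 output for simplicity; it is not the luci_interpreter_pal implementation (which operates on whole tensors and performs requantization in fixed-point integer arithmetic), and the helpers quantize_value and requantize_value are hypothetical names introduced only for this example.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Float -> quantized: q = clamp(round(x / scale) + zero_point), clamped to the int8 range here.
std::int8_t quantize_value(float x, float scale, std::int32_t zero_point)
{
  const std::int32_t q = static_cast<std::int32_t>(std::round(x / scale)) + zero_point;
  return static_cast<std::int8_t>(std::min<std::int32_t>(127, std::max<std::int32_t>(-128, q)));
}

// Quantized -> quantized: rescale by effective_scale = in_scale / out_scale, the same factor
// that call_requantize() above turns into a fixed-point multiplier and shift.
std::int8_t requantize_value(std::int32_t q_in, std::int32_t in_zero_point, float in_scale,
                             std::int32_t out_zero_point, float out_scale)
{
  const float effective_scale = in_scale / out_scale;
  const std::int32_t q =
    static_cast<std::int32_t>(std::round((q_in - in_zero_point) * effective_scale)) +
    out_zero_point;
  return static_cast<std::int8_t>(std::min<std::int32_t>(127, std::max<std::int32_t>(-128, q)));
}

int main()
{
  // 0.5 with scale 0.05 and zero point -10 quantizes to round(0.5 / 0.05) - 10 = 0.
  std::cout << static_cast<int>(quantize_value(0.5f, 0.05f, -10)) << "\n";
  // q = 20 under (scale 0.05, zero point -10) represents 1.5; requantizing to
  // (scale 0.1, zero point 0) gives round((20 + 10) * 0.5) = 15, the same real value.
  std::cout << static_cast<int>(requantize_value(20, -10, 0.05f, 0, 0.1f)) << "\n";
  return 0;
}

In the kernel itself, quantizeMultiplier() converts effective_scale into an integer multiplier and shift once per call, so the per-element requantization loop can avoid floating-point math.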