18#include "kernels/Utils.h"
19#include "PALQuantize.h"
// Requantizes an already-quantized input tensor into the output tensor's quantized type.
//
// The rescale ratio is input->scale() / output->scale(). The per-element conversion is
// delegated to luci_interpreter_pal::Requantize, which receives the input buffer, both
// zero points and the output buffer viewed as int8_t / uint8_t / int16_t depending on
// output->element_type(). Only S8, U8 and S16 outputs have a visible case; the visible
// throw ("Unsupported quantized type, yet!") presumably sits in the switch's `default:`
// label, which is not shown here.
//
// @tparam input_dtype  element type used to read the input buffer via getTensorData.
// @param input   source tensor; its scale, zero point and data are read.
// @param output  destination tensor; its scale and zero point are read, its data written.
// @throws std::runtime_error for an output element type without a case (see note above).
//
// NOTE(review): this excerpt is incomplete. `size`, `multiplier` and `shift` are used
// below but not defined in the visible lines, and effective_output_scale is never used
// visibly. Presumably multiplier/shift are derived from effective_output_scale (e.g. via
// quantizeMultiplier) and size is the element count. The function braces and any
// `break` between cases are also not shown. Confirm against the full file; without
// `break`s each case would fall through into the next one and finally into the throw.
29template <
typename input_dtype>
void call_requantize(
const Tensor *input,
Tensor *output)
// Ratio of input to output quantization scale; the integer multiplier/shift passed to
// Requantize presumably encode this value (their computation is not visible here).
34 const double effective_output_scale =
input->scale() /
output->scale();
// Read the input buffer with the caller-chosen element type.
41 const auto input_data = getTensorData<input_dtype>(input);
// Dispatch on the destination type; only the output buffer's element type changes per case.
43 switch (
output->element_type())
45 case loco::DataType::S8:
46 luci_interpreter_pal::Requantize(input_data,
size, multiplier, shift,
input->zero_point(),
47 output->zero_point(), getTensorData<int8_t>(output));
49 case loco::DataType::U8:
50 luci_interpreter_pal::Requantize(input_data,
size, multiplier, shift,
input->zero_point(),
51 output->zero_point(), getTensorData<uint8_t>(output));
53 case loco::DataType::S16:
54 luci_interpreter_pal::Requantize(input_data,
size, multiplier, shift,
input->zero_point(),
55 output->zero_point(), getTensorData<int16_t>(output));
// Reached for output types without a case above (assumed `default:`; label not visible).
58 throw std::runtime_error(
"Unsupported quantized type, yet!");
// Type validation for the Quantize kernel.
// NOTE(review): this appears to be the body of configure(). Its header, braces, the
// `LUCI_INTERPRETER_CHECK(` openers of the conditions below and the statements guarded
// by the `if`s are not visible in this excerpt; confirm against the full file.
//
// Special handling when the input is S16; the guarded statement is not shown here.
69 if (
input()->element_type() == loco::DataType::S16)
// Validate the output element type against the input element type.
72 switch (
input()->element_type())
// Float input: the visible part of the output-type condition accepts S8 or S16 (any
// further alternatives are on lines not shown).
74 case loco::DataType::FLOAT32:
77 output()->element_type() == loco::DataType::S8 ||
78 output()->element_type() == loco::DataType::S16);
// Already-quantized inputs share one check: the visible part accepts U8 or S16 outputs.
81 case loco::DataType::S16:
82 case loco::DataType::S8:
83 case loco::DataType::U8:
86 output()->element_type() == loco::DataType::U8 ||
87 output()->element_type() == loco::DataType::S16);
// Extra requirement for an S16 output; the guarded statement is not visible here.
88 if (
output()->element_type() == loco::DataType::S16)
// Any other input element type is rejected (presumably the `default:` label, not shown).
95 throw std::runtime_error(
"Unsupported type");
// Runtime dispatch of the Quantize kernel on the input element type.
// NOTE(review): this appears to be the body of execute(). Its header, braces and several
// statements (including any `break`s between cases) are not visible in this excerpt.
// Without `break`s, each case would fall through into the next one; confirm against the
// full file.
103 switch (
input()->element_type())
// Float input: quantize into the output's type. op_params is declared here; the lines
// that fill it are not shown.
105 case loco::DataType::FLOAT32:
107 tflite::QuantizationParams op_params;
110 const auto input_data = getTensorData<float>(
input());
// Inner dispatch on the output element type. The S8 body and the calls whose last
// argument is the U8 / S16 output buffer are on lines not shown.
112 switch (
output()->element_type())
114 case loco::DataType::S8:
120 case loco::DataType::U8:
124 getTensorData<uint8_t>(
output()));
127 case loco::DataType::S16:
131 getTensorData<int16_t>(
output()));
// Float input with an output type not handled above.
135 throw std::runtime_error(
"luci-intp Quantize(1) Unsupported type.");
// Already-quantized inputs. Their case bodies are not visible; presumably they requantize
// via call_requantize<...> with the matching input element type. Confirm.
139 case loco::DataType::S16:
144 case loco::DataType::S8:
149 case loco::DataType::U8:
// Any other input element type is rejected.
155 throw std::runtime_error(
"luci-intp Quantize(2) Unsupported type.");
// NOTE(review): the lines below are bare signatures/declarations with no bodies and no
// terminating `;`. They include a Quantize constructor and several related helpers, and
// some of the names (input(), zero_point()) are used in the code above. As written they
// are not valid C++, so confirm whether they belong in this file at all.
void resize(const Shape &new_shape)
int32_t zero_point() const
const Tensor * input() const
void execute() const override
void configure() override
Quantize(const Tensor *input, Tensor *output)
// As literally written, this would define LUCI_INTERPRETER_CHECK(cond) to expand to nothing.
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
Index shift(const Index &in_index, const Shape &shift_from)