18#include "kernels/Utils.h"
void Relu::configure()
{
  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
  if (input()->element_type() == DataType::S16)
  {
    LUCI_INTERPRETER_CHECK(input()->zero_point() == 0 && output()->zero_point() == 0);
  }

  if (input()->element_type() == DataType::U8 || input()->element_type() == DataType::S16)
  {
    double multiplier = input()->scale() / output()->scale();
    quantizeMultiplier(multiplier, &_output_multiplier, &_output_shift);
  }

  output()->resize(input()->shape());
}
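// Illustration (assumed scales, not taken from this file): with an input scale of 0.5 and an
// output scale of 0.25, the real multiplier passed to quantizeMultiplier() is 0.5 / 0.25 = 2.0.
// Under the usual TFLite-style encoding this becomes a Q31 fixed-point multiplier of 1 << 30
// with a left shift of 2, which the quantized eval paths below replay via
// tflite::MultiplyByQuantizedMultiplier().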
void Relu::execute() const
{
  switch (input()->element_type())
  {
    case DataType::FLOAT32:
      evalFloat();
      break;
    case DataType::U8:
      evalQuantized();
      break;
    case DataType::S16:
      evalQuantizedS16();
      break;
    default:
      throw std::runtime_error("luci-intp Relu Unsupported type.");
  }
}
void Relu::evalFloat() const
{
  const auto input_data = getTensorData<float>(input());
  const auto input_shape = getTensorShape(input());
  auto output_data = getTensorData<float>(output());
  auto output_shape = getTensorShape(output());

  luci_interpreter_pal::Relu(input_shape, input_data, output_shape, output_data);
}
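// luci_interpreter_pal::Relu computes the elementwise max(x, 0.0f) over the flat float buffers;
// the shapes are used to determine the flat element count to process.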
void Relu::evalQuantized() const
{
  tflite::ReluParams params;
  params.input_offset = input()->zero_point();
  params.output_offset = output()->zero_point();
  params.output_multiplier = _output_multiplier;
  params.output_shift = _output_shift;

  params.quantized_activation_min =
    std::max(static_cast<int32_t>(std::numeric_limits<uint8_t>::min()), params.output_offset);
  params.quantized_activation_max = static_cast<int32_t>(std::numeric_limits<uint8_t>::max());

  luci_interpreter_pal::ReluX(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
                              getTensorShape(output()), getTensorData<uint8_t>(output()));
}
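// Worked example (assumed quantization, for illustration only): with input and output zero
// points of 0 and a real multiplier of 2.0, the TFLite-style ReluX maps a quantized input of 3
// to output_offset + MultiplyByQuantizedMultiplier(3 - input_offset, ...) = 6, then clamps the
// result to [max(0, output_offset), 255], so negative real values collapse to the output zero
// point.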
void Relu::evalQuantizedS16() const
{
  const auto *input_data = getTensorData<int16_t>(input());
  auto *output_data = getTensorData<int16_t>(output());

  constexpr int32_t output_min = 0;
  constexpr int32_t output_max = std::numeric_limits<int16_t>::max();

  const int32_t num_elements = input()->shape().num_elements();

  for (int32_t i = 0; i < num_elements; ++i)
  {
    const int32_t input_val = input_data[i];
    int32_t output_val =
      tflite::MultiplyByQuantizedMultiplier(input_val, _output_multiplier, _output_shift);
    output_val = std::max(output_val, output_min);
    output_val = std::min(output_val, output_max);
    output_data[i] = static_cast<int16_t>(output_val);
  }
}

} // namespace kernels
} // namespace luci_interpreter
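// Usage sketch (assumption: the makeInputTensor/makeOutputTensor helpers and TestMemoryManager
// from the luci-interpreter kernel tests; illustrative only, not part of Relu.cpp):
//
//   auto memory_manager = std::make_unique<TestMemoryManager>();
//   Tensor input_tensor =
//     makeInputTensor<DataType::FLOAT32>({1, 4}, {-1.0f, 0.0f, 2.0f, -3.0f}, memory_manager.get());
//   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
//
//   Relu kernel(&input_tensor, &output_tensor);
//   kernel.configure();                              // type checks + output()->resize(...)
//   memory_manager->allocate_memory(output_tensor);  // output buffer becomes valid
//   kernel.execute();                                // output data: {0.0f, 0.0f, 2.0f, 0.0f}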
// Declarations referenced by this kernel (defined in kernels/Relu.h, kernels/Utils.h and the
// core Tensor/Kernel headers):
//
//   Relu(const Tensor *input, Tensor *output);
//   void configure() override;
//   void execute() const override;
//   const Tensor *input() const;
//   const Shape &shape() const;
//   int32_t num_elements() const;
//   int32_t zero_point() const;
//   void resize(const Shape &new_shape);
//   tflite::RuntimeShape getTensorShape(const Tensor *tensor);
//   void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift);
//   #define LUCI_INTERPRETER_CHECK(cond)