18#include "kernels/Utils.h"
35 if (
input()->element_type() == DataType::S16)
40 if (
input()->element_type() == DataType::U8 ||
input()->element_type() == DataType::S16)
50 switch (
input()->element_type())
52 case DataType::FLOAT32:
62 throw std::runtime_error(
"luci-intp Relu Unsupported type.");
66void Relu::evalFloat()
const
76void Relu::evalQuantized()
const
78 tflite::ReluParams params;
81 params.output_multiplier = _output_multiplier;
82 params.output_shift = _output_shift;
84 params.quantized_activation_min =
85 std::max(
static_cast<int32_t
>(std::numeric_limits<uint8_t>::min()), params.output_offset);
86 params.quantized_activation_max =
static_cast<int32_t
>(std::numeric_limits<uint8_t>::max());
92void Relu::evalQuantizedS16()
const
98 constexpr int32_t
output_max = std::numeric_limits<int16_t>::max();
102 for (int32_t
i = 0;
i < num_elements; ++
i)
106 tflite::MultiplyByQuantizedMultiplier(
input_val, _output_multiplier, _output_shift);
int32_t num_elements() const
void resize(const Shape &new_shape)
const Shape & shape() const
int32_t zero_point() const
Relu(const Tensor *input, Tensor *output)
void configure() override
const Tensor * input() const
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
T must_cast(loco::Node *node)