19#include "kernels/BinaryOpCommon.h"
20#include "kernels/Utils.h"
22#include <tensorflow/lite/kernels/internal/reference/binary_function.h>
23#include <tensorflow/lite/kernels/internal/reference/prelu.h>
// Quantization-parameter setup, selected by the input tensor's element type.
// NOTE(review): the enclosing function header is not part of this listing
// (presumably PRelu::configure() -- confirm), and the embedded line numbers
// (50, 53, 56, ...) skip values, so statements between the lines shown here
// are missing from view.
//
// U8 branch: _alpha_multipliers is resized to a single entry whose shift field
// is passed by address to a call whose start is not visible; identity_multiplier
// is handed to quantizeMultiplier() together with the addresses of
// _output_multiplier_identity and _output_shift_identity.
50 if (
input()->element_type() == DataType::U8)
53 _alpha_multipliers.resize(1);
56 &_alpha_multipliers[0].shift);
58 quantizeMultiplier(identity_multiplier, &_output_multiplier_identity, &_output_shift_identity);
// S16 branch: reads alpha's extent along its quantized_dimension() and the
// extent of input's last dimension (the expressions that consume these values
// are not visible -- presumably consistency checks; confirm), starts building
// a vector of real (double) multipliers, and performs the same
// quantizeMultiplier() call for the identity multiplier as the U8 branch.
60 else if (
input()->element_type() == DataType::S16)
71 alpha()->shape().dim(
alpha()->quantized_dimension()));
73 input()->shape().dim(
input()->shape().num_dims() - 1));
81 std::vector<double> real_multipliers =
87 quantizeMultiplier(identity_multiplier, &_output_multiplier_identity, &_output_shift_identity);
// Dispatch on the input tensor's element type.
// NOTE(review): the enclosing function header is not part of this listing
// (presumably PRelu::execute() -- confirm). Only the FLOAT32 case label is
// visible; any other case labels and the per-case calls sit on lines that are
// missing here.
94 switch (
input()->element_type())
96 case DataType::FLOAT32:
// Raises std::runtime_error with the message below (presumably the fallback
// for element types without a case label -- confirm).
106 throw std::runtime_error(
"luci-intp PRelu Unsupported type.");
// Float evaluation path of PRelu.
// Reads the input and alpha tensors as float data and obtains a float pointer
// to the output tensor. Two computation paths are visible: a call to
// tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>
// and an index loop over `size` that branches on input_data[i] >= 0.
// NOTE(review): the condition choosing between the two paths, the arguments of
// the broadcast call, and both branch bodies of the loop are on lines missing
// from this listing -- confirm against the full file.
110void PRelu::evalFloat()
const
112 const auto input_data = getTensorData<float>(
input());
113 const auto alpha_data = getTensorData<float>(
alpha());
115 auto output_data = getTensorData<float>(
output());
121 tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
// The loop index takes the same type as `size` (via decltype) so the
// comparison i < size is done on matching types.
128 for (
auto i =
decltype(
size){0}; i <
size; ++i)
130 if (input_data[i] >= 0)
// Quantized evaluation path that ends in uint8_t reference kernels.
// Fills a value-initialized tflite::PreluParams and hands the work to
// tflite::reference_ops (BroadcastPrelu4DSlow or Prelu<uint8_t>).
// NOTE(review): the fields set on original lines 141-144, the condition that
// selects between the two kernel calls, and the call arguments are missing
// from this listing.
138void PRelu::evalQuantized()
const
140 tflite::PreluParams op_params{};
// Multiplier/shift pair "1" comes from the identity multiplier members;
// pair "2" comes from the first (index 0) entry of _alpha_multipliers.
145 op_params.output_shift_1 = _output_shift_identity;
146 op_params.output_multiplier_1 = _output_multiplier_identity;
147 op_params.output_shift_2 = _alpha_multipliers[0].shift;
148 op_params.output_multiplier_2 = _alpha_multipliers[0].multiplier;
152 tflite::reference_ops::BroadcastPrelu4DSlow(
158 tflite::reference_ops::Prelu<uint8_t>(
// Computes a single int16 PRelu output element.
//
// input_val     - quantized input element.
// alpha_val     - quantized alpha element for this position.
// identity_mult - multiplier/shift pair applied in the branch that rescales
//                 input_val on its own.
// alpha_mult    - multiplier/shift pair applied in the other branch.
// Returns the rescaled value clamped to the int16_t range.
//
// NOTE(review): the ternary condition (original line 172) and the first
// operand of the alpha branch (original line 175) are missing from this
// listing; presumably the condition is input_val >= 0 and the alpha branch
// rescales input_val * alpha_val -- confirm against the full file.
164static inline int16_t evalElemS16PRelu(int16_t input_val, int16_t alpha_val,
165 const ChannelQuantMultipliers &identity_mult,
166 const ChannelQuantMultipliers &alpha_mult)
// Clamp bounds are held as int32_t so they compare directly with the 32-bit
// result of MultiplyByQuantizedMultiplier.
168 constexpr int32_t quantized_min = std::numeric_limits<int16_t>::min();
169 constexpr int32_t quantized_max = std::numeric_limits<int16_t>::max();
171 const int32_t output_val =
173 ? tflite::MultiplyByQuantizedMultiplier(
static_cast<int32_t
>(input_val),
174 identity_mult.multiplier, identity_mult.shift)
176 alpha_mult.multiplier, alpha_mult.
shift);
// Saturate to [quantized_min, quantized_max]; the int32_t result is then
// implicitly converted to the int16_t return type.
177 const int32_t clamped_output = std::min(quantized_max, std::max(quantized_min, output_val));
178 return clamped_output;
// int16 evaluation path with one alpha multiplier per channel of the last
// input dimension.
// Walks the input as (outer_dims_size x quant_dim_size), where outer_dims_size
// is the product of all dimensions before last_dim and quant_dim_size is the
// extent of last_dim, and evaluates each element with evalElemS16PRelu().
// NOTE(review): the definitions of input_shape, last_dim, input_data and the
// output pointer are on lines missing from this listing.
181void PRelu::evalQuantizedS16()
const
186 const int16_t *alpha_data =
alpha()->
data<int16_t>();
// Multiplier pair for the identity path, initialized in
// (shift, multiplier) order.
189 const ChannelQuantMultipliers pos_mult{_output_shift_identity, _output_multiplier_identity};
// Collapse every dimension before the last one into a single outer extent.
193 int32_t outer_dims_size = 1;
194 for (
int i = 0; i < last_dim; ++i)
195 outer_dims_size *= input_shape.Dims(i);
196 int32_t quant_dim_size = input_shape.Dims(last_dim);
198 for (int32_t outer_dims = 0; outer_dims < outer_dims_size; ++outer_dims)
199 for (int32_t quant_channel = 0; quant_channel < quant_dim_size; ++quant_channel)
// The alpha multiplier is looked up per channel; alpha_data is indexed by
// the channel only.
201 const ChannelQuantMultipliers &neg_mult = _alpha_multipliers[quant_channel];
// NOTE(review): as shown, offset is only the start of the current outer
// row and does not include quant_channel. Original lines 203-205 are
// missing from this listing -- confirm that the channel index is added to
// offset before input_data[offset] is read.
202 size_t offset =
static_cast<size_t>(outer_dims) *
static_cast<size_t>(quant_dim_size);
206 evalElemS16PRelu(input_data[
offset], alpha_data[quant_channel], pos_mult, neg_mult);
void resize(const Shape &new_shape)
const Shape & shape() const
const std::vector< int32_t > & zero_points() const
int32_t zero_point() const
void execute() const override
const Tensor * alpha() const
const Tensor * input() const
void configure() override
PRelu(const Tensor *input, const Tensor *alpha, Tensor *output)
#define LUCI_INTERPRETER_CHECK(cond)
__global uchar * offset(const Image *img, int x, int y)
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
std::vector< ChannelQuantMultipliers > quantizeMultipliers(const std::vector< double > &effective_scale)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
std::vector< double > getQuantizedConvolutionMultiplers(float input_scale, const std::vector< float > &filter_scale, float output_scale)
Index shift(const Index &in_index, const Shape &shift_from)
int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)