20#include "kernels/BinaryOpCommon.h"
21#include "kernels/Utils.h"
25#include <tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h>
43 if (
input1()->element_type() == DataType::S16)
48 output()->zero_point() == 0);
56 switch (
input1()->element_type())
58 case DataType::FLOAT32:
71 throw std::runtime_error(
"luci-intp Mul Unsupported type.");
75void Mul::evalFloat()
const
77 tflite::ArithmeticParams
params{};
80 const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
85 luci_interpreter_pal::BroadcastMul4DSlow(
97template <
typename T>
void Mul::evalInteger()
const
99 tflite::ArithmeticParams
params{};
102 const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
107 luci_interpreter_pal::BroadcastMul4DSlow(
119void Mul::evalQuantizedS16()
const
123 const auto output_scale =
static_cast<double>(
output()->
scale());
127 int32_t output_multiplier;
131 int32_t activation_min{};
132 int32_t activation_max{};
135 auto fn = [output_multiplier, output_shift, activation_min, activation_max](
int16_t input1_val,
138 output = tflite::MultiplyByQuantizedMultiplier(
output, output_multiplier, output_shift);
const MulParams & params() const
void resize(const Shape &new_shape)
void configure() override
const Tensor * input1() const
void execute() const override
const Tensor * input2() const
Mul(const Tensor *input1, const Tensor *input2, Tensor *output, const MulParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void calculateActivationRangeQuantized(Activation activation, const Tensor *output, int32_t *activation_min, int32_t *activation_max)
void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
void BinaryOpBroadcastSlow(const tflite::RuntimeShape &unextended_input1_shape, const T *input1_data, const tflite::RuntimeShape &unextended_input2_shape, const T *input2_data, const tflite::RuntimeShape &unextended_output_shape, T *output_data, Op op)
T must_cast(loco::Node *node)