#include "BinaryArithmeticLayer.h"

#include <cker/operation/BinaryArithmeticOps.h>

template <nnfw::cker::BinaryArithmeticOpType arithmetic_type, typename T> struct Eval
{
  nnfw::cker::Shape _lhs_shape;
  nnfw::cker::Shape _rhs_shape;
  nnfw::cker::Shape _output_shape;
  nnfw::cker::BinaryArithmeticOpParam _op_params;
  bool _need_broadcast;

  Eval(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
       nnfw::cker::BinaryArithmeticOpParam op_params)
    : _op_params(std::move(op_params)), _need_broadcast(false)
  {
    // Static shapes never change after configure(), so cache them once up front
    if (!output->is_dynamic())
      updateCache(lhs, rhs, output);
  }

  void updateCache(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output)
  {
    _lhs_shape.ReplaceWith(getShape(lhs));
    _rhs_shape.ReplaceWith(getShape(rhs));
    _output_shape.ReplaceWith(getShape(output));
    _need_broadcast = nnfw::cker::ProcessBroadcastShapes(_lhs_shape, _rhs_shape, &_op_params);
  }

  void operator()(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output)
  {
    // A dynamic output may change shape between executions; refresh the cache for it
    if (output->is_dynamic())
      updateCache(lhs, rhs, output);

    auto lhs_buffer = getBuffer<T>(lhs);
    auto rhs_buffer = getBuffer<T>(rhs);
    auto output_buffer = getBuffer<T>(output);
    if (_need_broadcast)
    {
      nnfw::cker::BroadcastBinaryArithmeticOp<arithmetic_type>(
        _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
    }
    else
    {
      nnfw::cker::BinaryArithmeticOp<arithmetic_type>(
        _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
    }
  }
};
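// Illustration only, not part of the original file: the Eval functor is type-erased into a
// std::function that the layer calls on every execution. A minimal sketch, assuming a float
// ADD kernel whose op_params activation range was already filled in:
//
//   nnfw::cker::BinaryArithmeticOpParam op_params;
//   std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)>
//     kernel = Eval<nnfw::cker::BinaryArithmeticOpType::ADD, float>(lhs, rhs, output, op_params);
//   kernel(lhs, rhs, output); // dispatches to the broadcast or plain cker kernel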
template <nnfw::cker::BinaryArithmeticOpType arithmetic_type>
std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)>
generateKernelGeneric(const IPortableTensor *lhs, const IPortableTensor *rhs,
                      IPortableTensor *output, const ir::Activation activation,
                      nnfw::cker::BinaryArithmeticOpParam op_params)
{
  switch (lhs->data_type())
  {
    case OperandType::FLOAT32:
    {
      float output_activation_min = 0, output_activation_max = 0;
      CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
      op_params.float_activation_max = output_activation_max;
      op_params.float_activation_min = output_activation_min;
      return Eval<arithmetic_type, float>(lhs, rhs, output, op_params);
    }
    case OperandType::INT32:
    {
      int32_t output_activation_min = 0, output_activation_max = 0;
      CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
      op_params.quantized_activation_max = output_activation_max;
      op_params.quantized_activation_min = output_activation_min;
      return Eval<arithmetic_type, int32_t>(lhs, rhs, output, op_params);
    }
    case OperandType::INT64:
    {
      int64_t output_activation_min = 0, output_activation_max = 0;
      CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
      op_params.int64_activation_max = output_activation_max;
      op_params.int64_activation_min = output_activation_min;
      return Eval<arithmetic_type, int64_t>(lhs, rhs, output, op_params);
    }
    case OperandType::BOOL8:
    {
      if (activation != ir::Activation::NONE)
      {
        throw std::runtime_error(
          "BinaryArithmetic(generic): Fused activation is not supported with bool8 type");
      }
      int32_t output_activation_min = 0, output_activation_max = 0;
      CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
      op_params.quantized_activation_max = output_activation_max;
      op_params.quantized_activation_min = output_activation_min;
      static_assert(sizeof(bool) == 1, "cpu backend supports bool type which is 1 byte");
      return Eval<arithmetic_type, bool>(lhs, rhs, output, op_params);
    }
    default:
      throw std::runtime_error{"BinaryArithmetic(generic): Unsupported data type"};
  }
}
void setAddOrSubQuant8Params(const IPortableTensor *lhs, const IPortableTensor *rhs,
                             IPortableTensor *output, ir::Activation activation,
                             nnfw::cker::BinaryArithmeticOpParam *params)
{
  int32_t output_activation_min, output_activation_max;
  CalculateActivationRangeQuantized(activation, output, &output_activation_min,
                                    &output_activation_max);
  nnfw::cker::BinaryArithmeticOpParam &op_params = *params;
  op_params.quantized_activation_max = output_activation_max;
  op_params.quantized_activation_min = output_activation_min;
  // Shift the 8-bit inputs left before rescaling to keep precision in integer arithmetic
  op_params.left_shift = 20;
  op_params.input1_offset = -lhs->data_zero_point();
  op_params.input2_offset = -rhs->data_zero_point();
  op_params.output_offset = output->data_zero_point();

  // Normalize both input scales by twice the larger one so each real multiplier stays <= 0.5
  const double norm_max_scale = 2 * std::max(lhs->data_scale(), rhs->data_scale());
  const double real_lhs_scale = lhs->data_scale() / norm_max_scale;
  const double real_rhs_scale = rhs->data_scale() / norm_max_scale;
  // The output multiplier folds the normalization and the left shift back out of the result
  const double real_output_scale =
    norm_max_scale / (output->data_scale() * (1 << op_params.left_shift));

  QuantizeMultiplier(real_lhs_scale, &op_params.input1_multiplier, &op_params.input1_shift);
  QuantizeMultiplier(real_rhs_scale, &op_params.input2_multiplier, &op_params.input2_shift);
  QuantizeMultiplier(real_output_scale, &op_params.output_multiplier, &op_params.output_shift);
}
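// Worked example (not in the original file): for lhs scale 0.5 and rhs scale 0.25,
// norm_max_scale = 2 * max(0.5, 0.25) = 1.0, so real_lhs_scale = 0.5 and
// real_rhs_scale = 0.25. With output scale 1.0 and left_shift = 20,
// real_output_scale = 1.0 / (1.0 * (1 << 20)) = 2^-20, which exactly cancels the
// precision-preserving left shift applied to the summed inputs.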
void setMulQuant8Params(const IPortableTensor *lhs, const IPortableTensor *rhs,
                        IPortableTensor *output, ir::Activation activation,
                        nnfw::cker::BinaryArithmeticOpParam *params)
{
  int32_t output_activation_min, output_activation_max;
  CalculateActivationRangeQuantized(activation, output, &output_activation_min,
                                    &output_activation_max);
  nnfw::cker::BinaryArithmeticOpParam &op_params = *params;
  op_params.quantized_activation_max = output_activation_max;
  op_params.quantized_activation_min = output_activation_min;
  op_params.input1_offset = -lhs->data_zero_point();
  op_params.input2_offset = -rhs->data_zero_point();
  op_params.output_offset = output->data_zero_point();

  // For multiplication the scales combine into a single rescale of the product
  double real_multiplier = lhs->data_scale() * rhs->data_scale() / output->data_scale();
  QuantizeMultiplier(real_multiplier, &op_params.output_multiplier, &op_params.output_shift);
}
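// Worked example (not in the original file): with lhs scale 0.5, rhs scale 0.25 and output
// scale 0.125, real_multiplier = 0.5 * 0.25 / 0.125 = 1.0; QuantizeMultiplier then splits the
// value into a fixed-point multiplier and a power-of-two shift for the integer kernel.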
void BinaryArithmeticLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                                      IPortableTensor *output, const ir::Activation activation,
                                      const ArithmeticType arithmetic_type)
{
  assert(lhs != nullptr);
  assert(rhs != nullptr);
  assert(output != nullptr);

  _lhs = lhs;
  _rhs = rhs;
  _output = output;

  nnfw::cker::BinaryArithmeticOpParam op_params;
  switch (arithmetic_type)
  {
    case ArithmeticType::kAdd:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::ADD, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::ADD, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::ADD>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kSub:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        // Subtraction reuses the add parameters with the rhs multiplier negated
        op_params.input2_multiplier *= -1;
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::SUB, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        op_params.input2_multiplier *= -1;
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::SUB, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::SUB>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kMul:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::MUL, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::MUL, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::MUL>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kDiv:
      if (_lhs->data_type() == OperandType::FLOAT32)
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::DIV>(
          _lhs, _rhs, _output, activation, op_params);
      }
      else
      {
        throw std::runtime_error{
          "BinaryArithmetic(Div): Div operation does not support non-float data types yet"};
      }
      break;
    default:
      throw std::runtime_error{"BinaryArithmetic: Unsupported BinaryArithmetic type"};
  }
}