template <nnfw::cker::BinaryArithmeticOpType arithmetic_type, typename T> struct Eval
{
  nnfw::cker::Shape _lhs_shape;
  nnfw::cker::Shape _rhs_shape;
  nnfw::cker::Shape _output_shape;
  nnfw::cker::BinaryArithmeticOpParam _op_params;
  bool _need_broadcast{false};

  Eval(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
       nnfw::cker::BinaryArithmeticOpParam op_params)
    : _op_params(std::move(op_params))
  {
    // Static shapes are fixed after configure(), so cache them once up front.
    if (!output->is_dynamic())
      updateCache(lhs, rhs, output);
  }

  void updateCache(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output)
  {
    const auto lhs_shape = getShape(lhs), rhs_shape = getShape(rhs), out_shape = getShape(output);
    _lhs_shape.ReplaceWith(lhs_shape.DimensionsCount(), lhs_shape.DimsData());
    _rhs_shape.ReplaceWith(rhs_shape.DimensionsCount(), rhs_shape.DimsData());
    _output_shape.ReplaceWith(out_shape.DimensionsCount(), out_shape.DimsData());
    _need_broadcast = nnfw::cker::ProcessBroadcastShapes(_lhs_shape, _rhs_shape, &_op_params);
  }

  void operator()(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output)
  {
    // A dynamic output may change shape between runs; refresh the cache lazily.
    if (output->is_dynamic())
      updateCache(lhs, rhs, output);

    auto lhs_buffer = getBuffer<T>(lhs);
    auto rhs_buffer = getBuffer<T>(rhs);
    auto output_buffer = getBuffer<T>(output);
    if (_need_broadcast)
      nnfw::cker::BroadcastBinaryArithmeticOp<arithmetic_type>(
        _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
    else
      nnfw::cker::BinaryArithmeticOp<arithmetic_type>(
        _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
  }
};
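Because Eval is a copyable functor with the kernel's exact call signature, generateKernelGeneric below can return an instance directly as a std::function. A minimal wiring sketch, assuming lhs, rhs, and output are already-configured tensors (illustrative, not onert code):

// Hypothetical wiring: lhs, rhs, output are already-configured IPortableTensors.
std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> kernel =
  Eval<nnfw::cker::BinaryArithmeticOpType::ADD, float>(lhs, rhs, output, op_params);
kernel(lhs, rhs, output); // dispatches to the broadcast or plain cker kernel via the cached flag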
template <nnfw::cker::BinaryArithmeticOpType arithmetic_type>
std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)>
generateKernelGeneric(const IPortableTensor *lhs, const IPortableTensor *rhs,
                      IPortableTensor *output, const ir::Activation activation,
                      nnfw::cker::BinaryArithmeticOpParam op_params)
{
  switch (lhs->data_type())
  {
    case OperandType::FLOAT32:
    {
      float output_activation_min = 0, output_activation_max = 0;
      CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
      op_params.float_activation_min = output_activation_min;
      op_params.float_activation_max = output_activation_max;
      return Eval<arithmetic_type, float>(lhs, rhs, output, op_params);
    }
    case OperandType::INT32:
    {
      int32_t output_activation_min = 0, output_activation_max = 0;
      CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
      op_params.quantized_activation_min = output_activation_min;
      op_params.quantized_activation_max = output_activation_max;
      return Eval<arithmetic_type, int32_t>(lhs, rhs, output, op_params);
    }
    case OperandType::INT64:
    {
      int64_t output_activation_min = 0, output_activation_max = 0;
      CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
      op_params.int64_activation_min = output_activation_min;
      op_params.int64_activation_max = output_activation_max;
      return Eval<arithmetic_type, int64_t>(lhs, rhs, output, op_params);
    }
    case OperandType::BOOL8:
    {
      if (activation != ir::Activation::NONE)
        throw std::runtime_error(
          "BinaryArithmetic(generic): Fused activation is not supported with bool8 type");
      static_assert(sizeof(bool) == 1, "cpu backend supports bool type which is 1 byte");
      return Eval<arithmetic_type, bool>(lhs, rhs, output, op_params);
    }
    default:
      throw std::runtime_error{"BinaryArithmetic(generic): Unsupported data type"};
  }
}
void setAddOrSubQuant8Params(const IPortableTensor *lhs, const IPortableTensor *rhs,
                             IPortableTensor *output, ir::Activation activation,
                             nnfw::cker::BinaryArithmeticOpParam *params)
{
  int32_t output_activation_min, output_activation_max;
  CalculateActivationRangeQuantized(activation, output, &output_activation_min,
                                    &output_activation_max);
  auto &op_params = *params;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  op_params.left_shift = 20; // integer headroom for the rescaled inputs
  op_params.input1_offset = -lhs->data_zero_point();
  op_params.input2_offset = -rhs->data_zero_point();
  op_params.output_offset = output->data_zero_point();

  // Normalize both input scales by twice the larger one, then fold the output
  // scale (and the headroom shift) back in through the output multiplier.
  const double norm_max_scale = 2 * std::max(lhs->data_scale(), rhs->data_scale());
  const double real_lhs_scale = lhs->data_scale() / norm_max_scale;
  const double real_rhs_scale = rhs->data_scale() / norm_max_scale;
  const double real_output_scale =
    norm_max_scale / (output->data_scale() * (1 << op_params.left_shift));
  QuantizeMultiplier(real_lhs_scale, &op_params.input1_multiplier, &op_params.input1_shift);
  QuantizeMultiplier(real_rhs_scale, &op_params.input2_multiplier, &op_params.input2_shift);
  QuantizeMultiplier(real_output_scale, &op_params.output_multiplier, &op_params.output_shift);
}
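The scale algebra above is easy to sanity-check in plain doubles: normalizing both inputs by twice the larger scale, shifting left by 20 bits for headroom, and folding everything back through real_output_scale reproduces the exact requantization ratios. A self-contained check with hypothetical scale values:

#include <algorithm>
#include <cassert>
#include <cmath>

int main()
{
  const double s_lhs = 0.5, s_rhs = 0.25, s_out = 0.75; // hypothetical scales
  const int left_shift = 20;

  const double norm_max_scale = 2 * std::max(s_lhs, s_rhs);
  const double real_lhs_scale = s_lhs / norm_max_scale;
  const double real_rhs_scale = s_rhs / norm_max_scale;
  const double real_output_scale = norm_max_scale / (s_out * (1 << left_shift));

  // Composing input scale, headroom shift, and output scale recovers exactly
  // the requantization ratio scale_in / scale_out each operand needs:
  assert(std::abs(real_lhs_scale * (1 << left_shift) * real_output_scale - s_lhs / s_out) < 1e-12);
  assert(std::abs(real_rhs_scale * (1 << left_shift) * real_output_scale - s_rhs / s_out) < 1e-12);
  return 0;
}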
void setMulQuant8Params(const IPortableTensor *lhs, const IPortableTensor *rhs,
                        IPortableTensor *output, ir::Activation activation,
                        nnfw::cker::BinaryArithmeticOpParam *params)
{
  int32_t output_activation_min, output_activation_max;
  CalculateActivationRangeQuantized(activation, output, &output_activation_min,
                                    &output_activation_max);
  auto &op_params = *params;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  op_params.input1_offset = -lhs->data_zero_point();
  op_params.input2_offset = -rhs->data_zero_point();
  op_params.output_offset = output->data_zero_point();

  // For MUL the three scales combine into a single multiplier.
  const double real_multiplier = lhs->data_scale() * rhs->data_scale() / output->data_scale();
  QuantizeMultiplier(real_multiplier, &op_params.output_multiplier, &op_params.output_shift);
}
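QuantizeMultiplier turns each real-valued scale into an integer multiplier plus a power-of-two shift so the kernels can stay in fixed-point arithmetic. A minimal sketch of the usual decomposition, with edge cases omitted (quantizeMultiplierSketch is a hypothetical stand-in, not the nnfw implementation):

#include <cmath>
#include <cstdint>

// Sketch only: decompose m as (q / 2^31) * 2^shift with q in [2^30, 2^31).
// Edge cases (m == 0, very small m needing shift capping) are omitted.
void quantizeMultiplierSketch(double m, int32_t *quantized_multiplier, int *shift)
{
  const double q = std::frexp(m, shift); // m == q * 2^shift, q in [0.5, 1)
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
  if (q_fixed == (1ll << 31)) // rounding carried q up to exactly 1.0
  {
    q_fixed /= 2;
    ++(*shift);
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}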
void BinaryArithmeticLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                                      IPortableTensor *output, const ir::Activation activation,
                                      const ArithmeticType arithmetic_type)
{
  assert(lhs != nullptr);
  assert(rhs != nullptr);
  assert(output != nullptr);

  _lhs = lhs;
  _rhs = rhs;
  _output = output;

  nnfw::cker::BinaryArithmeticOpParam op_params;
  switch (arithmetic_type)
  {
    case ArithmeticType::kAdd:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel = Eval<nnfw::cker::BinaryArithmeticOpType::ADD, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel = Eval<nnfw::cker::BinaryArithmeticOpType::ADD, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::ADD>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kSub:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        op_params.input2_multiplier *= -1; // SUB reuses the ADD rescaling with a negated rhs
        _kernel = Eval<nnfw::cker::BinaryArithmeticOpType::SUB, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        op_params.input2_multiplier *= -1;
        _kernel = Eval<nnfw::cker::BinaryArithmeticOpType::SUB, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::SUB>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kMul:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel = Eval<nnfw::cker::BinaryArithmeticOpType::MUL, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel = Eval<nnfw::cker::BinaryArithmeticOpType::MUL, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::MUL>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kDiv:
      if (_lhs->data_type() == OperandType::FLOAT32)
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::DIV>(
          _lhs, _rhs, _output, activation, op_params);
      }
      else
      {
        throw std::runtime_error{
          "BinaryArithmetic(Div): Div operation does not support non-float data types yet"};
      }
      break;
    default:
      throw std::runtime_error{"BinaryArithmetic: Unsupported BinaryArithmetic type"};
  }
}
Referenced declarations (signatures as documented):

// BinaryArithmeticLayer members
const IPortableTensor *_lhs;
const IPortableTensor *_rhs;
IPortableTensor *_output;
std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> _kernel;
void configure(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
               const ir::Activation activation, const ArithmeticType arithmetic_type);

// Eval cache members
nnfw::cker::Shape _lhs_shape;
nnfw::cker::Shape _rhs_shape;
nnfw::cker::Shape _output_shape;
nnfw::cker::BinaryArithmeticOpParam _op_params;

// Helpers
void Shape::ReplaceWith(int dimensions_count, const int32_t *dims_data);
bool ProcessBroadcastShapes(const Shape &shape0, const Shape &shape1, BinaryArithmeticOpParam *params);
nnfw::cker::Shape getShape(const IPortableTensor *tensor);
void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift);
void CalculateActivationRangeQuantized(ir::Activation activation, const IPortableTensor *output,
                                       int32_t *act_min, int32_t *act_max);
template <typename T>
void CalculateActivationRange(ir::Activation activation, T *activation_min, T *activation_max);
ir::DataType IPortableTensor::data_type() const; // IPortableTensor: a tensor class that is portable for other backends

// nnfw::cker::BinaryArithmeticOpParam fields used above
int32_t quantized_activation_min, quantized_activation_max;
float float_activation_min, float_activation_max;
int64_t int64_activation_min, int64_activation_max;
int32_t input1_multiplier, input2_multiplier, output_multiplier;