19#include "../KernelGenerator.h"
20#include "../Validator.h"
27void Validator::visit(
const ir::operation::BinaryArithmetic &) {
_supported =
true; }
32 switch (arithmetic_type_ir)
43 throw std::runtime_error(
"cpu KernelGenerator : Not supported operation yet");
55 auto ofm_tensor = _tensor_reg->getPortableTensor(ofm_index);
56 auto lhs_tensor = _tensor_reg->getPortableTensor(lhs_index);
57 auto rhs_tensor = _tensor_reg->getPortableTensor(rhs_index);
59 auto fn = std::make_unique<ops::BinaryArithmeticLayer>();
61 fn->configure(lhs_tensor, rhs_tensor, ofm_tensor, activation,
75template <nnfw::cker::BinaryArithmeticOpType arithmetic_type,
typename T>
struct Eval
83 Eval(
const IPortableTensor *lhs,
const IPortableTensor *rhs, IPortableTensor *output,
87 if (!output->is_dynamic())
88 updateCache(lhs, rhs, output);
91 void updateCache(
const IPortableTensor *lhs,
const IPortableTensor *rhs, IPortableTensor *output)
99 void operator()(
const IPortableTensor *lhs,
const IPortableTensor *rhs, IPortableTensor *output)
104 updateCache(lhs, rhs, output);
108 auto lhs_buffer = getBuffer<T>(lhs);
109 auto rhs_buffer = getBuffer<T>(rhs);
110 auto output_buffer = getBuffer<T>(output);
113 nnfw::cker::BroadcastBinaryArithmeticOp<arithmetic_type>(
114 _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
118 nnfw::cker::BinaryArithmeticOp<arithmetic_type>(
119 _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
124template <nnfw::cker::BinaryArithmeticOpType arithmetic_type>
125std::function<void(
const IPortableTensor *,
const IPortableTensor *, IPortableTensor *)>
126generateKernelGeneric(
const IPortableTensor *lhs,
const IPortableTensor *rhs,
130 switch (lhs->data_type())
132 case OperandType::FLOAT32:
134 float output_activation_min = 0, output_activation_max = 0;
138 return Eval<arithmetic_type, float>(lhs, rhs, output, op_params);
141 case OperandType::INT32:
143 int32_t output_activation_min = 0, output_activation_max = 0;
147 return Eval<arithmetic_type, int32_t>(lhs, rhs, output, op_params);
150 case OperandType::INT64:
152 int64_t output_activation_min = 0, output_activation_max = 0;
156 return Eval<arithmetic_type, int64_t>(lhs, rhs, output, op_params);
159 case OperandType::BOOL8:
162 throw std::runtime_error(
163 "BinaryArithmetic(generic): Fused activation is not supported with bool8 type");
164 int32_t output_activation_min = 0, output_activation_max = 0;
166 static_assert(
sizeof(bool) == 1,
"cpu backend supports bool type which is 1 byte");
167 return Eval<arithmetic_type, bool>(lhs, rhs, output, op_params);
171 throw std::runtime_error{
"BinaryArithmetic(generic): Unsupported data type"};
175void setAddOrSubQuant8Params(
const IPortableTensor *lhs,
const IPortableTensor *rhs,
179 int32_t output_activation_min, output_activation_max;
181 &output_activation_max);
194 const double norm_max_scale = 2 * std::max(lhs->data_scale(), rhs->data_scale());
195 const double real_lhs_scale = lhs->data_scale() / norm_max_scale;
196 const double real_rhs_scale = rhs->data_scale() / norm_max_scale;
198 const double real_output_scale =
207void setMulQuant8Params(
const IPortableTensor *lhs,
const IPortableTensor *rhs,
211 int32_t output_activation_min, output_activation_max;
213 &output_activation_max);
222 double real_multiplier = lhs->data_scale() * rhs->data_scale() /
output->data_scale();
232 assert(lhs !=
nullptr);
233 assert(rhs !=
nullptr);
234 assert(output !=
nullptr);
241 switch (arithmetic_type)
248 Eval<nnfw::cker::BinaryArithmeticOpType::ADD, uint8_t>(
_lhs,
_rhs,
_output, op_params);
254 Eval<nnfw::cker::BinaryArithmeticOpType::ADD, int8_t>(
_lhs,
_rhs,
_output, op_params);
259 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::ADD>(
269 Eval<nnfw::cker::BinaryArithmeticOpType::SUB, uint8_t>(
_lhs,
_rhs,
_output, op_params);
276 Eval<nnfw::cker::BinaryArithmeticOpType::SUB, int8_t>(
_lhs,
_rhs,
_output, op_params);
281 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::SUB>(
291 Eval<nnfw::cker::BinaryArithmeticOpType::MUL, uint8_t>(
_lhs,
_rhs,
_output, op_params);
298 Eval<nnfw::cker::BinaryArithmeticOpType::MUL, int8_t>(
_lhs,
_rhs,
_output, op_params);
302 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::MUL>(
309 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::DIV>(
316 throw std::runtime_error{
317 "BinaryArithmetic(Div): Div operation does not support non-float data types yet"};
321 throw std::runtime_error{
"BinaryArithmetic: Unsupported BinaryArithmetic type"};
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
A tensor class that is portable across backends.
ir::DataType data_type() const override final
std::unique_ptr< exec::IFunction > _return_fn
std::function< void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> _kernel
const IPortableTensor * _rhs
const IPortableTensor * _lhs
void configure(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output, const ir::Activation activation, const ArithmeticType arithmetic_type)
IPortableTensor * _output
const OperandIndex & at(IOIndex set_index) const
const OperandIndexSequence & getOutputs() const override
OperandIndexSequence & getInputs()
const Param & param() const
nnfw::cker::Shape _output_shape
nnfw::cker::Shape _rhs_shape
nnfw::cker::BinaryArithmeticOpParam _op_params
nnfw::cker::Shape _lhs_shape
bool ProcessBroadcastShapes(const Shape &shape0, const Shape &shape1, BinaryArithmeticOpParam *params)
nnfw::cker::Shape getShape(const IPortableTensor *tensor)
void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
void CalculateActivationRangeQuantized(ir::Activation activation, const IPortableTensor *output, int32_t *act_min, int32_t *act_max)
ops::ArithmeticType convertArithmeticType(ir::operation::BinaryArithmetic::ArithmeticType arithmetic_type_ir)
void CalculateActivationRange(ir::Activation activation, T *activation_min, T *activation_max)
int32_t quantized_activation_max
int64_t int64_activation_min
int32_t input2_multiplier
int32_t quantized_activation_min
float float_activation_max
int32_t output_multiplier
int32_t input1_multiplier
int64_t int64_activation_max
float float_activation_min
ArithmeticType arithmetic_type