ONE - On-device Neural Engine
onert::backend::cpu::ops::BinaryArithmeticLayer Class Reference

#include <BinaryArithmeticLayer.h>

Public Member Functions

 BinaryArithmeticLayer ()
 
void configure (const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output, const ir::Activation activation, const ArithmeticType arithmetic_type)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Protected Attributes

const IPortableTensor *_lhs
 
const IPortableTensor *_rhs
 
IPortableTensor *_output
 
std::function< void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> _kernel
 

Detailed Description
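
This layer performs element-wise binary arithmetic (Add, Sub, Mul, and Div) for the cpu backend. configure() selects a kernel specialized for the operation and the operands' data type and binds it to _kernel; run() then executes that kernel on the configured tensors.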

Definition at line 36 of file BinaryArithmeticLayer.h.

Constructor & Destructor Documentation

◆ BinaryArithmeticLayer()

onert::backend::cpu::ops::BinaryArithmeticLayer::BinaryArithmeticLayer ( )
inline

Definition at line 39 of file BinaryArithmeticLayer.h.

BinaryArithmeticLayer() : _lhs(nullptr), _rhs(nullptr), _output(nullptr)
{
  // DO NOTHING
}
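
Note that the constructor only null-initializes the tensor pointers; configure() must be called to bind the tensors and a kernel before run() is invoked.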

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::BinaryArithmeticLayer::configure ( const IPortableTensor * lhs,
const IPortableTensor * rhs,
IPortableTensor * output,
const ir::Activation  activation,
const ArithmeticType  arithmetic_type 
)

Definition at line 228 of file BinaryArithmeticLayer.cc.

void BinaryArithmeticLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                                      IPortableTensor *output, const ir::Activation activation,
                                      const ArithmeticType arithmetic_type)
{
  assert(lhs != nullptr);
  assert(rhs != nullptr);
  assert(output != nullptr);

  _lhs = lhs;
  _rhs = rhs;
  _output = output;

  nnfw::cker::BinaryArithmeticOpParam op_params;
  switch (arithmetic_type)
  {
    case ArithmeticType::kAdd:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::ADD, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::ADD, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::ADD>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kSub:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        op_params.input2_multiplier *= -1;
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::SUB, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        op_params.input2_multiplier *= -1;
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::SUB, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::SUB>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kMul:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::MUL, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::MUL, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::MUL>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kDiv:
      if (_lhs->data_type() == OperandType::FLOAT32)
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::DIV>(
          _lhs, _rhs, _output, activation, op_params);
      }
      else
      {
        // TODO Support quantized type
        // TODO Support integer type with zero check
        throw std::runtime_error{
          "BinaryArithmetic(Div): Div operation does not support non-float data types yet"};
      }
      break;
    default:
      throw std::runtime_error{"BinaryArithmetic: Unsupported BinaryArithmetic type"};
  }
}

References _kernel, _lhs, _output, _rhs, onert::backend::IPortableTensor::data_type(), nnfw::cker::BinaryArithmeticOpParam::input2_multiplier, onert::backend::cpu::ops::kAdd, onert::backend::cpu::ops::kDiv, onert::backend::cpu::ops::kMul, and onert::backend::cpu::ops::kSub.
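
A minimal usage sketch follows. It assumes lhs, rhs, and out are valid tensors supplied by the surrounding backend code (not shown), and that no fused activation is wanted; this is an illustration, not code from the repository.

// Hedged usage sketch: lhs, rhs, and out are assumed to be valid
// IPortableTensor* owned by the backend's tensor registry.
using namespace onert::backend::cpu::ops;

BinaryArithmeticLayer layer;
// Bind an ADD kernel specialized for the tensors' data type.
layer.configure(lhs, rhs, out, onert::ir::Activation::NONE, ArithmeticType::kAdd);
layer.run(); // executes the kernel selected in configure()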

◆ run()

void onert::backend::cpu::ops::BinaryArithmeticLayer::run ( )
override virtual
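
The page omits the body of run(). Based on the _kernel field below and its "Referenced by configure(), and run()" note, a plausible reconstruction is a single call into the stored closure; treat this as a sketch, not the verbatim source:

// Sketch inferred from the documented fields; not copied from the source.
void BinaryArithmeticLayer::run()
{
  // _kernel was bound in configure() to a type-specialized arithmetic kernel.
  _kernel(_lhs, _rhs, _output);
}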

Field Documentation

◆ _kernel

std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> onert::backend::cpu::ops::BinaryArithmeticLayer::_kernel
protected

Definition at line 55 of file BinaryArithmeticLayer.h.

Referenced by configure(), and run().
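
The design point here is type erasure: configure() pays the dispatch cost once (operation, data type, quantization parameters) and stores a ready-to-run closure, so run() is branch-free. A self-contained sketch of the pattern, using a hypothetical Tensor type rather than ONE's API:

#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for IPortableTensor, for illustration only.
struct Tensor
{
  std::vector<float> data;
};

using Kernel = std::function<void(const Tensor *, const Tensor *, Tensor *)>;

// configure()-style factory: the operation choice is baked into the closure
// once, so the hot path performs no per-call switch.
Kernel makeAddKernel()
{
  return [](const Tensor *lhs, const Tensor *rhs, Tensor *out) {
    out->data.resize(lhs->data.size());
    for (std::size_t i = 0; i < lhs->data.size(); ++i)
      out->data[i] = lhs->data[i] + rhs->data[i];
  };
}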

◆ _lhs

const IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_lhs
protected

◆ _output

IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_output
protected

◆ _rhs

const IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_rhs
protected

The documentation for this class was generated from the following files:

BinaryArithmeticLayer.h
BinaryArithmeticLayer.cc