ONE - On-device Neural Engine
onert::backend::cpu::ops::BinaryArithmeticLayer Class Reference

#include <BinaryArithmeticLayer.h>

Public Member Functions

 BinaryArithmeticLayer ()
 
void configure (const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output, const ir::Activation activation, const ArithmeticType arithmetic_type)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Protected Attributes

const IPortableTensor * _lhs
 
const IPortableTensor * _rhs
 
IPortableTensor * _output
 
std::function< void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> _kernel
 

Detailed Description

Definition at line 42 of file BinaryArithmeticLayer.h.
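As a quick orientation, a typical backend flow constructs the layer, calls configure() once with the operand tensors, activation, and arithmetic type, and then calls run() for each inference. A minimal sketch under those assumptions; the helper makeAddLayer is hypothetical, and the tensors are assumed to come from the cpu backend's tensor registry rather than being created here:

#include <memory>
#include "BinaryArithmeticLayer.h"

// Hypothetical helper: wires up one ADD operation between two tensors.
// Obtaining the IPortableTensor instances is out of scope for this sketch.
std::unique_ptr<onert::exec::IFunction>
makeAddLayer(const onert::backend::IPortableTensor *lhs,
             const onert::backend::IPortableTensor *rhs,
             onert::backend::IPortableTensor *output)
{
  auto layer = std::make_unique<onert::backend::cpu::ops::BinaryArithmeticLayer>();
  layer->configure(lhs, rhs, output, onert::ir::Activation::NONE,
                   onert::backend::cpu::ops::ArithmeticType::kAdd);
  return layer; // the caller invokes run() once per inference
}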

Constructor & Destructor Documentation

◆ BinaryArithmeticLayer()

onert::backend::cpu::ops::BinaryArithmeticLayer::BinaryArithmeticLayer ( )
inline

Definition at line 45 of file BinaryArithmeticLayer.h.

BinaryArithmeticLayer() : _lhs(nullptr), _rhs(nullptr), _output(nullptr)
{
  // DO NOTHING
}

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::BinaryArithmeticLayer::configure ( const IPortableTensor * lhs,
const IPortableTensor * rhs,
IPortableTensor * output,
const ir::Activation activation,
const ArithmeticType arithmetic_type
)

Definition at line 186 of file BinaryArithmeticLayer.cc.

{
  assert(lhs != nullptr);
  assert(rhs != nullptr);
  assert(output != nullptr);

  _lhs = lhs;
  _rhs = rhs;
  _output = output;

  nnfw::cker::BinaryArithmeticOpParam op_params;
  switch (arithmetic_type)
  {
    case ArithmeticType::kAdd:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::ADD, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::ADD, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::ADD>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kSub:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        op_params.input2_multiplier *= -1;
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::SUB, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        op_params.input2_multiplier *= -1;
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::SUB, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::SUB>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kMul:
      if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::MUL, uint8_t>(_lhs, _rhs, _output, op_params);
      }
      else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
      {
        setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
        _kernel =
          Eval<nnfw::cker::BinaryArithmeticOpType::MUL, int8_t>(_lhs, _rhs, _output, op_params);
      }
      else
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::MUL>(
          _lhs, _rhs, _output, activation, op_params);
      }
      break;
    case ArithmeticType::kDiv:
      if (_lhs->data_type() == OperandType::FLOAT32)
      {
        _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::DIV>(
          _lhs, _rhs, _output, activation, op_params);
      }
      else
      {
        // TODO Support quantized type
        // TODO Support integer type with zero check
        throw std::runtime_error{
          "BinaryArithmetic(Div): Div operation does not support non-float data types yet"};
      }
      break;
    default:
      throw std::runtime_error{"BinaryArithmetic: Unsupported BinaryArithmetic type"};
  }
}

References _kernel, _lhs, _output, _rhs, onert::backend::IPortableTensor::data_type(), nnfw::cker::BinaryArithmeticOpParam::input2_multiplier, onert::backend::cpu::ops::kAdd, onert::backend::cpu::ops::kDiv, onert::backend::cpu::ops::kMul, and onert::backend::cpu::ops::kSub.
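A note on the quantized SUB path above: configure() reuses setAddOrSubQuant8Params() and then negates op_params.input2_multiplier. This works because the quantized add kernel computes a scaled sum of the two rescaled inputs, so flipping the sign of the second input's multiplier turns a + b into a - b without a separate parameter-setup routine. A float-domain sketch of the algebra (scaled_add and the literal scales are illustrative, not the cker API, which operates on fixed-point multipliers and zero points):

#include <cassert>

// Illustrative only: models the "multiply each input by its own
// multiplier, then sum" shape of the quantized add kernel.
float scaled_add(float a, float m1, float b, float m2) { return a * m1 + b * m2; }

int main()
{
  const float a = 3.0f, b = 2.0f;
  // ADD parameterization: both multipliers positive.
  assert(scaled_add(a, 1.0f, b, 1.0f) == 5.0f);
  // SUB reuses the same kernel with the second multiplier negated.
  assert(scaled_add(a, 1.0f, b, -1.0f) == 1.0f);
  return 0;
}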

◆ run()

void onert::backend::cpu::ops::BinaryArithmeticLayer::run ( )
overridevirtual
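The page does not reproduce the body of run(), but the _kernel documentation below notes that the member is referenced by both configure() and run(): configure() builds the kernel closure, and run() invokes it on the stored tensors. A minimal sketch of that dispatch, consistent with the member declarations above (the exact body lives in BinaryArithmeticLayer.cc):

void BinaryArithmeticLayer::run()
{
  // Invoke the kernel that configure() selected for the operand
  // data type and arithmetic type.
  _kernel(_lhs, _rhs, _output);
}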

Field Documentation

◆ _kernel

std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> onert::backend::cpu::ops::BinaryArithmeticLayer::_kernel
protected

Definition at line 61 of file BinaryArithmeticLayer.h.

Referenced by configure(), and run().

◆ _lhs

const IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_lhs
protected

◆ _output

IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_output
protected

◆ _rhs

const IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_rhs
protected

The documentation for this class was generated from the following files:

BinaryArithmeticLayer.h
BinaryArithmeticLayer.cc