ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::cpu::ops::BinaryArithmeticLayer Class Reference

#include <BinaryArithmeticLayer.h>

Collaboration diagram for onert::backend::cpu::ops::BinaryArithmeticLayer:

Public Member Functions

 BinaryArithmeticLayer ()
 
void configure (const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output, const ir::Activation activation, const ArithmeticType arithmetic_type)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Protected Attributes

const IPortableTensor * _lhs
 
const IPortableTensor * _rhs
 
IPortableTensor * _output
 
std::function< void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> _kernel
 

Detailed Description

Definition at line 36 of file BinaryArithmeticLayer.h.

Constructor & Destructor Documentation

◆ BinaryArithmeticLayer()

onert::backend::cpu::ops::BinaryArithmeticLayer::BinaryArithmeticLayer ( )
inline

Definition at line 39 of file BinaryArithmeticLayer.h.

39 : _lhs(nullptr), _rhs(nullptr), _output(nullptr)
40 {
41 // DO NOTHING
42 }

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::BinaryArithmeticLayer::configure ( const IPortableTensor * lhs,
const IPortableTensor * rhs,
IPortableTensor * output,
const ir::Activation  activation,
const ArithmeticType  arithmetic_type 
)

Definition at line 180 of file BinaryArithmeticLayer.cc.

183{
184 assert(lhs != nullptr);
185 assert(rhs != nullptr);
186 assert(output != nullptr);
187
188 _lhs = lhs;
189 _rhs = rhs;
190 _output = output;
191
192 nnfw::cker::BinaryArithmeticOpParam op_params;
193 switch (arithmetic_type)
194 {
195 case ArithmeticType::kAdd:
196 if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
197 {
198 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
199 _kernel =
200 Eval<nnfw::cker::BinaryArithmeticOpType::ADD, uint8_t>(_lhs, _rhs, _output, op_params);
201 }
202 else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
203 {
204 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
205 _kernel =
206 Eval<nnfw::cker::BinaryArithmeticOpType::ADD, int8_t>(_lhs, _rhs, _output, op_params);
207 }
208
209 else
210 {
211 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::ADD>(
212 _lhs, _rhs, _output, activation, op_params);
213 }
214 break;
215 case ArithmeticType::kSub:
216 if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
217 {
218 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
219 op_params.input2_multiplier *= -1;
220 _kernel =
221 Eval<nnfw::cker::BinaryArithmeticOpType::SUB, uint8_t>(_lhs, _rhs, _output, op_params);
222 }
223 else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
224 {
225 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
226 op_params.input2_multiplier *= -1;
227 _kernel =
228 Eval<nnfw::cker::BinaryArithmeticOpType::SUB, int8_t>(_lhs, _rhs, _output, op_params);
229 }
230
231 else
232 {
233 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::SUB>(
234 _lhs, _rhs, _output, activation, op_params);
235 }
236 break;
237 case ArithmeticType::kMul:
238 if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
239 {
241 setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
242 _kernel =
243 Eval<nnfw::cker::BinaryArithmeticOpType::MUL, uint8_t>(_lhs, _rhs, _output, op_params);
244 }
245 else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
246 {
248 setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
249 _kernel =
250 Eval<nnfw::cker::BinaryArithmeticOpType::MUL, int8_t>(_lhs, _rhs, _output, op_params);
251 }
252 else
253 {
254 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::MUL>(
255 _lhs, _rhs, _output, activation, op_params);
256 }
257 break;
258 case ArithmeticType::kDiv:
259 if (_lhs->data_type() == OperandType::FLOAT32)
260 {
261 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::DIV>(
262 _lhs, _rhs, _output, activation, op_params);
263 }
264 else
265 {
266 // TODO Support quantized type
267 // TODO Support integer type with zero check
268 throw std::runtime_error{
269 "BinaryArithmetic(Div): Div operation does not support non-float data types yet"};
270 }
271 break;
272 default:
273 throw std::runtime_error{"BinaryArithmetic: Unsupported BinaryArithmetic type"};
274 }
275}
ir::DataType data_type() const override final
std::function< void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> _kernel

References _kernel, _lhs, _output, _rhs, onert::backend::IPortableTensor::data_type(), nnfw::cker::BinaryArithmeticOpParam::input2_multiplier, onert::backend::cpu::ops::kAdd, onert::backend::cpu::ops::kDiv, onert::backend::cpu::ops::kMul, and onert::backend::cpu::ops::kSub.

◆ run()

void onert::backend::cpu::ops::BinaryArithmeticLayer::run ( )
overridevirtual

Field Documentation

◆ _kernel

std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)> onert::backend::cpu::ops::BinaryArithmeticLayer::_kernel
protected

Definition at line 55 of file BinaryArithmeticLayer.h.

Referenced by configure(), and run().

◆ _lhs

const IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_lhs
protected

◆ _output

IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_output
protected

◆ _rhs

const IPortableTensor* onert::backend::cpu::ops::BinaryArithmeticLayer::_rhs
protected

The documentation for this class was generated from the following files: