ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::cpu::ops::BatchMatMulLayer Class Reference

#include <BatchMatMulLayer.h>

Collaboration diagram for onert::backend::cpu::ops::BatchMatMulLayer:

Public Member Functions

 BatchMatMulLayer ()
 
 ~BatchMatMulLayer ()
 
void batchMatMulFloat32 ()
 
void configure (const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x, bool adj_y, IPortableTensor *output)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 33 of file BatchMatMulLayer.h.

Constructor & Destructor Documentation

◆ BatchMatMulLayer()

onert::backend::cpu::ops::BatchMatMulLayer::BatchMatMulLayer ( )

Definition at line 24 of file BatchMatMulLayer.cc.

25 : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
26 _kernel(new nnfw::cker::BatchMatMul())
27{
28 // DO NOTHING
29}

◆ ~BatchMatMulLayer()

onert::backend::cpu::ops::BatchMatMulLayer::~BatchMatMulLayer ( )
default

Member Function Documentation

◆ batchMatMulFloat32()

void onert::backend::cpu::ops::BatchMatMulLayer::batchMatMulFloat32 ( )

Definition at line 33 of file BatchMatMulLayer.cc.

34{
35 nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
36 nnfw::cker::Shape lhs_shape = getShape(_lhs);
37 nnfw::cker::Shape rhs_shape = getShape(_rhs);
38 nnfw::cker::Shape output_shape = getShape(_output);
39
40 // TODO implement for constant input
41
42 batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y, _rhs->is_constant());
43 batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x,
44 _adj_y, output_shape, getBuffer<float>(_output));
45}
void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y, bool rhs_const)
Prepare temporary area for calculation.
Definition BatchMatMul.h:47
bool is_constant() const override final
Return true if the tensor is constant.
nnfw::cker::Shape output_shape
nnfw::cker::Shape getShape(const IPortableTensor *tensor)

References onert::backend::cpu::ops::getShape(), onert::backend::IPortableTensor::is_constant(), output_shape, and nnfw::cker::BatchMatMul::prepare().

Referenced by run().

◆ configure()

void onert::backend::cpu::ops::BatchMatMulLayer::configure ( const IPortableTensor * lhs,
const IPortableTensor * rhs,
bool  adj_x,
bool  adj_y,
IPortableTensor * output 
)

Definition at line 47 of file BatchMatMulLayer.cc.

49{
50 assert(lhs != nullptr);
51 assert(rhs != nullptr);
52 assert(output != nullptr);
53
54 _lhs = lhs;
55 _rhs = rhs;
56 _adj_x = adj_x;
57 _adj_y = adj_y;
58 _output = output;
59}

◆ run()

void onert::backend::cpu::ops::BatchMatMulLayer::run ( )
overridevirtual

Implements onert::exec::IFunction.

Definition at line 61 of file BatchMatMulLayer.cc.

62{
63 if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
64 {
65 batchMatMulFloat32();
66 }
67 else
68 {
69 throw std::runtime_error{"BatchMatMul: unsupported data type"};
70 }
71}
ir::DataType data_type() const override final

References batchMatMulFloat32(), and onert::backend::IPortableTensor::data_type().


The documentation for this class was generated from the following files: