ONE - On-device Neural Engine
onert::backend::cpu::ops::BatchMatMulLayer Class Reference

#include <BatchMatMulLayer.h>


Public Member Functions

 BatchMatMulLayer ()
 
 ~BatchMatMulLayer ()
 
void batchMatMulFloat32 ()
 
void configure (const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x, bool adj_y, IPortableTensor *output)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

CPU backend kernel for batched matrix multiplication. It multiplies two FLOAT32 input tensors batch by batch, optionally treating either operand as adjointed (its last two dimensions transposed) via the adj_x / adj_y flags passed to configure().

Definition at line 42 of file BatchMatMulLayer.h.
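A minimal usage sketch, assuming the caller already owns valid IPortableTensor objects (tensor construction is backend-specific and omitted; the buildAndRun helper is hypothetical). Only the configure() and run() calls below are the API documented on this page:

#include <memory>

#include <BatchMatMulLayer.h>

using onert::backend::IPortableTensor;
using onert::backend::cpu::ops::BatchMatMulLayer;

void buildAndRun(const IPortableTensor *lhs, const IPortableTensor *rhs,
                 IPortableTensor *output)
{
  auto layer = std::make_unique<BatchMatMulLayer>();
  // adj_x / adj_y ask the kernel to transpose the last two dimensions
  // of lhs / rhs, respectively, before multiplying.
  layer->configure(lhs, rhs, /*adj_x=*/false, /*adj_y=*/false, output);
  layer->run(); // dispatches to batchMatMulFloat32() for FLOAT32 inputs
}

Note that configure() stores raw pointers, so the tensors must remain alive until run() has finished.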

Constructor & Destructor Documentation

◆ BatchMatMulLayer()

onert::backend::cpu::ops::BatchMatMulLayer::BatchMatMulLayer ( )

Definition at line 30 of file BatchMatMulLayer.cc.

BatchMatMulLayer::BatchMatMulLayer()
  : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
    _kernel(new nnfw::cker::BatchMatMul())
{
  // DO NOTHING
}

◆ ~BatchMatMulLayer()

onert::backend::cpu::ops::BatchMatMulLayer::~BatchMatMulLayer ( ) = default

Member Function Documentation

◆ batchMatMulFloat32()

void onert::backend::cpu::ops::BatchMatMulLayer::batchMatMulFloat32 ( )

Definition at line 39 of file BatchMatMulLayer.cc.

void BatchMatMulLayer::batchMatMulFloat32()
{
  nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
  nnfw::cker::Shape lhs_shape = getShape(_lhs);
  nnfw::cker::Shape rhs_shape = getShape(_rhs);
  nnfw::cker::Shape output_shape = getShape(_output);

  // TODO implement for constant input

  batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y);
  batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x,
                     _adj_y, output_shape, getBuffer<float>(_output));
}

References onert::backend::cpu::ops::getShape(), output_shape, and nnfw::cker::BatchMatMul::prepare().

Referenced by run().
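To make the adj_x / adj_y semantics concrete, here is a naive standalone reference. It is illustrative only, not the nnfw::cker::BatchMatMul implementation (which prepares temporary buffers and is optimized, and whose batch-broadcasting rules are not modeled here); it assumes row-major tensors with an identical batch dimension on both sides:

#include <cstdio>
#include <vector>

// lhs: [batch, M, K] (or [batch, K, M] if adj_x), row-major.
// rhs: [batch, K, N] (or [batch, N, K] if adj_y), row-major.
// out: [batch, M, N].
void naiveBatchMatMul(const float *lhs, const float *rhs, float *out,
                      int batch, int m, int k, int n, bool adj_x, bool adj_y)
{
  for (int b = 0; b < batch; ++b)
  {
    const float *L = lhs + b * m * k;
    const float *R = rhs + b * k * n;
    float *O = out + b * m * n;
    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j)
      {
        float acc = 0.f;
        for (int p = 0; p < k; ++p)
        {
          // Adjointed operands are read with their last two dims swapped.
          float l = adj_x ? L[p * m + i] : L[i * k + p];
          float r = adj_y ? R[j * k + p] : R[p * n + j];
          acc += l * r;
        }
        O[i * n + j] = acc;
      }
  }
}

int main()
{
  // One batch of a 2x3 by 3x2 product -> 2x2: prints "4 5 / 10 11".
  std::vector<float> lhs{1, 2, 3, 4, 5, 6};
  std::vector<float> rhs{1, 0, 0, 1, 1, 1};
  std::vector<float> out(4);
  naiveBatchMatMul(lhs.data(), rhs.data(), out.data(), 1, 2, 3, 2, false, false);
  std::printf("%g %g / %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}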

◆ configure()

void onert::backend::cpu::ops::BatchMatMulLayer::configure ( const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x, bool adj_y, IPortableTensor *output )

Definition at line 53 of file BatchMatMulLayer.cc.

void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                                 bool adj_x, bool adj_y, IPortableTensor *output)
{
  assert(lhs != nullptr);
  assert(rhs != nullptr);
  assert(output != nullptr);

  _lhs = lhs;
  _rhs = rhs;
  _adj_x = adj_x;
  _adj_y = adj_y;
  _output = output;
}

◆ run()

void onert::backend::cpu::ops::BatchMatMulLayer::run ( )
override, virtual

Implements onert::exec::IFunction.

Definition at line 67 of file BatchMatMulLayer.cc.

void BatchMatMulLayer::run()
{
  if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
  {
    batchMatMulFloat32();
  }
  else
  {
    throw std::runtime_error{"BatchMatMul: unsupported data type"};
  }
}

References batchMatMulFloat32(), and onert::backend::IPortableTensor::data_type().

Referenced by package.infer.session::inference().
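Since run() throws for anything other than FLOAT32 inputs, a caller that cannot guarantee operand types ahead of time needs to handle the exception. A caller-side sketch (the safeRun helper and its logging policy are illustrative, not part of onert):

#include <cstdio>
#include <stdexcept>

#include <BatchMatMulLayer.h>

// Runs a configured layer, logging instead of propagating the
// unsupported-data-type error thrown by run().
void safeRun(onert::backend::cpu::ops::BatchMatMulLayer &layer)
{
  try
  {
    layer.run();
  }
  catch (const std::runtime_error &e)
  {
    std::fprintf(stderr, "BatchMatMul failed: %s\n", e.what());
  }
}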


The documentation for this class was generated from the following files:

BatchMatMulLayer.h
BatchMatMulLayer.cc