ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::cpu::ops::BatchMatMulLayer Class Reference

#include <BatchMatMulLayer.h>

Collaboration diagram for onert::backend::cpu::ops::BatchMatMulLayer:

Public Member Functions

 BatchMatMulLayer ()
 
 ~BatchMatMulLayer ()
 
void batchMatMulFloat32 ()
 
void configure (const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x, bool adj_y, IPortableTensor *output)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 33 of file BatchMatMulLayer.h.

Constructor & Destructor Documentation

◆ BatchMatMulLayer()

onert::backend::cpu::ops::BatchMatMulLayer::BatchMatMulLayer ( )

Definition at line 53 of file BatchMatMulLayer.cc.

54 : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
55 _kernel(new nnfw::cker::BatchMatMul())
56{
57 // DO NOTHING
58}

◆ ~BatchMatMulLayer()

onert::backend::cpu::ops::BatchMatMulLayer::~BatchMatMulLayer ( )
default

Member Function Documentation

◆ batchMatMulFloat32()

void onert::backend::cpu::ops::BatchMatMulLayer::batchMatMulFloat32 ( )

Definition at line 62 of file BatchMatMulLayer.cc.

63{
64 nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
65 nnfw::cker::Shape lhs_shape = getShape(_lhs);
66 nnfw::cker::Shape rhs_shape = getShape(_rhs);
67 nnfw::cker::Shape output_shape = getShape(_output);
68
69 // TODO implement for constant input
70
71 batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y, _rhs->is_constant());
72 batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x,
73 _adj_y, output_shape, getBuffer<float>(_output));
74}
void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y, bool rhs_const)
Prepare temporary area for calculation.
Definition BatchMatMul.h:47
bool is_constant() const override final
Return true if the tensor is constant.
const luci_interpreter::RuntimeShape output_shape
nnfw::cker::Shape getShape(const IPortableTensor *tensor)

References onert::backend::cpu::ops::getShape(), onert::backend::IPortableTensor::is_constant(), output_shape, and nnfw::cker::BatchMatMul::prepare().

Referenced by run().

◆ configure()

void onert::backend::cpu::ops::BatchMatMulLayer::configure ( const IPortableTensor *  lhs,
const IPortableTensor *  rhs,
bool  adj_x,
bool  adj_y,
IPortableTensor *  output 
)

Definition at line 76 of file BatchMatMulLayer.cc.

78{
79 assert(lhs != nullptr);
80 assert(rhs != nullptr);
81 assert(output != nullptr);
82
83 _lhs = lhs;
84 _rhs = rhs;
85 _adj_x = adj_x;
86 _adj_y = adj_y;
87 _output = output;
88}

◆ run()

void onert::backend::cpu::ops::BatchMatMulLayer::run ( )
overridevirtual

Implements onert::exec::IFunction.

Definition at line 90 of file BatchMatMulLayer.cc.

91{
92 if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
93 {
94 batchMatMulFloat32();
95 }
96 else
97 {
98 throw std::runtime_error{"BatchMatMul: unsupported data type"};
99 }
100}
ir::DataType data_type() const override final

References batchMatMulFloat32(), and onert::backend::IPortableTensor::data_type().


The documentation for this class was generated from the following files: