ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::cpu::ops::MatrixBandPartLayer Class Reference

#include <MatrixBandPartLayer.h>

Collaboration diagram for onert::backend::cpu::ops::MatrixBandPartLayer:

Public Member Functions

 MatrixBandPartLayer ()
 
void matrixBandPartFloat32 ()
 
void matrixBandPartQuant8 ()
 
void configure (const IPortableTensor *input, const IPortableTensor *num_lower_diag, const IPortableTensor *num_upper_diag, IPortableTensor *output)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 27 of file MatrixBandPartLayer.h.

Constructor & Destructor Documentation

◆ MatrixBandPartLayer()

onert::backend::cpu::ops::MatrixBandPartLayer::MatrixBandPartLayer ( )

Definition at line 26 of file MatrixBandPartLayer.cc.

27 : _input(nullptr), _num_lower_diag(nullptr), _num_upper_diag(nullptr), _output(nullptr)
28{
29 // DO NOTHING
30}

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::MatrixBandPartLayer::configure ( const IPortableTensor * input,
const IPortableTensor * num_lower_diag,
const IPortableTensor * num_upper_diag,
IPortableTensor * output 
)

Definition at line 50 of file MatrixBandPartLayer.cc.

53{
54 _input = input;
55 _num_lower_diag = num_lower_diag;
56 _num_upper_diag = num_upper_diag;
57 _output = output;
58}

◆ matrixBandPartFloat32()

void onert::backend::cpu::ops::MatrixBandPartLayer::matrixBandPartFloat32 ( )

Definition at line 32 of file MatrixBandPartLayer.cc.

33{
34 if (_num_lower_diag->data_type() == OperandType::INT64)
35 {
36 nnfw::cker::MatrixBandPart<int64_t>(
37 *getBuffer<int64_t>(_num_lower_diag), *getBuffer<int64_t>(_num_upper_diag), getShape(_input),
38 getBuffer<float>(_input), getShape(_output), getBuffer<float>(_output));
39 }
40 else
41 {
42 nnfw::cker::MatrixBandPart<int32_t>(
43 *getBuffer<int32_t>(_num_lower_diag), *getBuffer<int32_t>(_num_upper_diag), getShape(_input),
44 getBuffer<float>(_input), getShape(_output), getBuffer<float>(_output));
45 }
46}
Declarations of the referenced functions: ir::DataType IPortableTensor::data_type() const override final
and nnfw::cker::Shape getShape(const IPortableTensor *tensor).

References onert::backend::IPortableTensor::data_type(), and onert::backend::cpu::ops::getShape().

Referenced by run().

◆ matrixBandPartQuant8()

void onert::backend::cpu::ops::MatrixBandPartLayer::matrixBandPartQuant8 ( )

Definition at line 48 of file MatrixBandPartLayer.cc.

48{ throw std::runtime_error{"NYI"}; }

Referenced by run().

◆ run()

void onert::backend::cpu::ops::MatrixBandPartLayer::run ( )
override virtual

Implements onert::exec::IFunction.

Definition at line 60 of file MatrixBandPartLayer.cc.

61{
62 if (_num_lower_diag->data_type() != _num_upper_diag->data_type())
63 {
64 throw std::runtime_error{"MatrixBandpart: num_lower and num_upper must have the same type"};
65 }
66
67 if (_input->data_type() == OperandType::FLOAT32)
68 {
69 matrixBandPartFloat32();
70 }
71 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
72 {
73 matrixBandPartQuant8();
74 }
75 else
76 {
77 throw std::runtime_error{"MatrixBandpart: unsupported data type"};
78 }
79}

References onert::backend::IPortableTensor::data_type(), matrixBandPartFloat32(), and matrixBandPartQuant8().


The documentation for this class was generated from the following files: