ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::cpu::ops::MatrixBandPartLayer Class Reference

#include <MatrixBandPartLayer.h>

Collaboration diagram for onert::backend::cpu::ops::MatrixBandPartLayer:

Public Member Functions

 MatrixBandPartLayer ()
 
void matrixBandPartFloat32 ()
 
void matrixBandPartQuant8 ()
 
void configure (const IPortableTensor *input, const IPortableTensor *num_lower_diag, const IPortableTensor *num_upper_diag, IPortableTensor *output)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 33 of file MatrixBandPartLayer.h.

Constructor & Destructor Documentation

◆ MatrixBandPartLayer()

onert::backend::cpu::ops::MatrixBandPartLayer::MatrixBandPartLayer ( )

Definition at line 32 of file MatrixBandPartLayer.cc.

33 : _input(nullptr), _num_lower_diag(nullptr), _num_upper_diag(nullptr), _output(nullptr)
34{
35 // DO NOTHING
36}

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::MatrixBandPartLayer::configure ( const IPortableTensor * input,
const IPortableTensor * num_lower_diag,
const IPortableTensor * num_upper_diag,
IPortableTensor * output 
)

Definition at line 56 of file MatrixBandPartLayer.cc.

59{
60 _input = input;
61 _num_lower_diag = num_lower_diag;
62 _num_upper_diag = num_upper_diag;
63 _output = output;
64}

◆ matrixBandPartFloat32()

void onert::backend::cpu::ops::MatrixBandPartLayer::matrixBandPartFloat32 ( )

Definition at line 38 of file MatrixBandPartLayer.cc.

39{
40 if (_num_lower_diag->data_type() == OperandType::INT64)
41 {
42 nnfw::cker::MatrixBandPart<int64_t>(
43 *getBuffer<int64_t>(_num_lower_diag), *getBuffer<int64_t>(_num_upper_diag), getShape(_input),
44 getBuffer<float>(_input), getShape(_output), getBuffer<float>(_output));
45 }
46 else
47 {
48 nnfw::cker::MatrixBandPart<int32_t>(
49 *getBuffer<int32_t>(_num_lower_diag), *getBuffer<int32_t>(_num_upper_diag), getShape(_input),
50 getBuffer<float>(_input), getShape(_output), getBuffer<float>(_output));
51 }
52}
ir::DataType data_type() const override final
nnfw::cker::Shape getShape(const IPortableTensor *tensor)

References onert::backend::IPortableTensor::data_type(), and onert::backend::cpu::ops::getShape().

Referenced by run().

◆ matrixBandPartQuant8()

void onert::backend::cpu::ops::MatrixBandPartLayer::matrixBandPartQuant8 ( )

Definition at line 54 of file MatrixBandPartLayer.cc.

54{ throw std::runtime_error{"NYI"}; }

Referenced by run().

◆ run()

void onert::backend::cpu::ops::MatrixBandPartLayer::run ( )
override virtual

Implements onert::exec::IFunction.

Definition at line 66 of file MatrixBandPartLayer.cc.

67{
68 if (_num_lower_diag->data_type() != _num_upper_diag->data_type())
69 {
70 throw std::runtime_error{"MatrixBandpart: num_lower and num_upper must have the same type"};
71 }
72
73 if (_input->data_type() == OperandType::FLOAT32)
74 {
75 matrixBandPartFloat32();
76 }
77 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
78 {
79 matrixBandPartQuant8();
80 }
81 else
82 {
83 throw std::runtime_error{"MatrixBandpart: unsupported data type"};
84 }
85}

References onert::backend::IPortableTensor::data_type(), matrixBandPartFloat32(), and matrixBandPartQuant8().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: