ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::cpu::ops::SplitLayer Class Reference

#include <SplitLayer.h>

Collaboration diagram for onert::backend::cpu::ops::SplitLayer:

Public Member Functions

 SplitLayer ()
 
template<typename T >
void split (void)
 
void configure (const IPortableTensor *input, const IPortableTensor *axis, uint16_t num_splits, std::vector< IPortableTensor * > &outputs)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 27 of file SplitLayer.h.

Constructor & Destructor Documentation

◆ SplitLayer()

onert::backend::cpu::ops::SplitLayer::SplitLayer ( )

Definition at line 57 of file SplitLayer.cc.

57 : _input(nullptr), _axis(nullptr), _num_splits(0), _outputs()
58{
59 // DO NOTHING
60}

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::SplitLayer::configure ( const IPortableTensor *  input,
const IPortableTensor *  axis,
uint16_t  num_splits,
std::vector< IPortableTensor * > &  outputs 
)

Definition at line 90 of file SplitLayer.cc.

92{
93 assert(input != nullptr);
94
95 _num_splits = num_splits;
96 _input = input;
97 _axis = axis;
98 _outputs = outputs;
99}

◆ run()

void onert::backend::cpu::ops::SplitLayer::run ( )
overridevirtual

Implements onert::exec::IFunction.

Definition at line 101 of file SplitLayer.cc.

102{
103 if (_input->data_type() == OperandType::FLOAT32)
104 {
105 split<float>();
106 }
107 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
108 {
109 split<uint8_t>();
110 }
111 else if (_input->data_type() == OperandType::INT32)
112 {
113 split<int32_t>();
114 }
115 else if (_input->data_type() == OperandType::INT64)
116 {
117 split<int64_t>();
118 }
119 else
120 {
121 throw std::runtime_error{"Split: unsupported input type"};
122 }
123}
ir::DataType data_type() const override final

References onert::backend::IPortableTensor::data_type().

◆ split()

template<typename T >
void onert::backend::cpu::ops::SplitLayer::split ( void  )

Definition at line 62 of file SplitLayer.cc.

63{
64 nnfw::cker::SplitParams op_params;
65 if (_axis->total_size() != sizeof(int32_t))
66 {
67 throw std::runtime_error("Split: wrong shape of axis");
68 }
69 auto axis = *getBuffer<int32_t>(_axis);
70 if (axis < 0)
71 {
72 axis += _input->getShape().rank();
73 }
74 op_params.axis = axis;
75 op_params.num_split = _num_splits;
76
77 std::vector<T *> outputPtrs;
78
79 for (const auto output : _outputs)
80 {
81 assert(output->total_size() == sizeOfData(output->data_type(), output->getShape().dims()));
82 outputPtrs.emplace_back(getBuffer<T>(output));
83 }
84
85 assert(_input->total_size() == sizeOfData(_input->data_type(), _input->getShape().dims()));
86 nnfw::cker::Split<T>(op_params, getShape(_input), getBuffer<T>(_input), getShape(_outputs[0]),
87 outputPtrs.data());
88}
size_t total_size() const override final
ir::Shape getShape() const override final
Get ir::Shape of tensor.
nnfw::cker::Shape getShape(const IPortableTensor *tensor)
uint32_t sizeOfData(OperandType type, const std::vector< int32_t > &dimensions)

References nnfw::cker::SplitParams::axis, onert::backend::IPortableTensor::data_type(), onert::backend::IPortableTensor::getShape(), onert::backend::cpu::ops::getShape(), nnfw::cker::SplitParams::num_split, onert::backend::cpu::ops::sizeOfData(), and onert::backend::IPortableTensor::total_size().


The documentation for this class was generated from the following files: