ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::cpu::ops::SplitLayer Class Reference

#include <SplitLayer.h>

Collaboration diagram for onert::backend::cpu::ops::SplitLayer:

Public Member Functions

 SplitLayer ()
 
template<typename T >
void split (void)
 
void configure (const IPortableTensor *input, const IPortableTensor *axis, uint16_t num_splits, std::vector< IPortableTensor * > &outputs)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 33 of file SplitLayer.h.

Constructor & Destructor Documentation

◆ SplitLayer()

onert::backend::cpu::ops::SplitLayer::SplitLayer ( )

Definition at line 32 of file SplitLayer.cc.

32 : _input(nullptr), _axis(nullptr), _num_splits(0), _outputs()
33{
34 // DO NOTHING
35}

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::SplitLayer::configure ( const IPortableTensor *  input,
const IPortableTensor *  axis,
uint16_t  num_splits,
std::vector< IPortableTensor * > &  outputs 
)

Definition at line 65 of file SplitLayer.cc.

67{
68 assert(input != nullptr);
69
70 _num_splits = num_splits;
71 _input = input;
72 _axis = axis;
73 _outputs = outputs;
74}

◆ run()

void onert::backend::cpu::ops::SplitLayer::run ( )
overridevirtual

Implements onert::exec::IFunction.

Definition at line 76 of file SplitLayer.cc.

77{
78 if (_input->data_type() == OperandType::FLOAT32)
79 {
80 split<float>();
81 }
82 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
83 {
84 split<uint8_t>();
85 }
86 else if (_input->data_type() == OperandType::INT32)
87 {
88 split<int32_t>();
89 }
90 else if (_input->data_type() == OperandType::INT64)
91 {
92 split<int64_t>();
93 }
94 else
95 {
96 throw std::runtime_error{"Split: unsupported input type"};
97 }
98}
ir::DataType data_type() const override final

References onert::backend::IPortableTensor::data_type().

Referenced by package.infer.session::inference().

◆ split()

template<typename T >
void onert::backend::cpu::ops::SplitLayer::split ( void  )

Definition at line 37 of file SplitLayer.cc.

38{
39 nnfw::cker::SplitParams op_params;
40 if (_axis->total_size() != sizeof(int32_t))
41 {
42 throw std::runtime_error("ArgMinMax: wrong shape of axis");
43 }
44 auto axis = *getBuffer<int32_t>(_axis);
45 if (axis < 0)
46 {
47 axis += _input->getShape().rank();
48 }
49 op_params.axis = axis;
50 op_params.num_split = _num_splits;
51
52 std::vector<T *> outputPtrs;
53
54 for (const auto output : _outputs)
55 {
56 assert(output->total_size() == sizeOfData(output->data_type(), output->getShape().dims()));
57 outputPtrs.emplace_back(getBuffer<T>(output));
58 }
59
60 assert(_input->total_size() == sizeOfData(_input->data_type(), _input->getShape().dims()));
61 nnfw::cker::Split<T>(op_params, getShape(_input), getBuffer<T>(_input), getShape(_outputs[0]),
62 outputPtrs.data());
63}
size_t total_size() const override final
ir::Shape getShape() const override final
Get ir::Shape of tensor.
nnfw::cker::Shape getShape(const IPortableTensor *tensor)
uint32_t sizeOfData(OperandType type, const std::vector< int32_t > &dimensions)

References nnfw::cker::SplitParams::axis, onert::backend::IPortableTensor::data_type(), onert::backend::IPortableTensor::getShape(), onert::backend::cpu::ops::getShape(), nnfw::cker::SplitParams::num_split, onert::backend::cpu::ops::sizeOfData(), and onert::backend::IPortableTensor::total_size().


The documentation for this class was generated from the following files: