ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::cpu::ops::SplitLayer Class Reference

#include <SplitLayer.h>

Collaboration diagram for onert::backend::cpu::ops::SplitLayer:

Public Member Functions

 SplitLayer ()
 
template<typename T >
void split (void)
 
void configure (const IPortableTensor *input, const IPortableTensor *axis, uint16_t num_splits, std::vector< IPortableTensor * > &outputs)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 27 of file SplitLayer.h.

Constructor & Destructor Documentation

◆ SplitLayer()

onert::backend::cpu::ops::SplitLayer::SplitLayer ( )

Definition at line 26 of file SplitLayer.cc.

26 : _input(nullptr), _axis(nullptr), _num_splits(0), _outputs()
27{
28 // DO NOTHING
29}

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::SplitLayer::configure ( const IPortableTensor *  input,
const IPortableTensor *  axis,
uint16_t  num_splits,
std::vector< IPortableTensor * > &  outputs 
)

Definition at line 59 of file SplitLayer.cc.

61{
62 assert(input != nullptr);
63
64 _num_splits = num_splits;
65 _input = input;
66 _axis = axis;
67 _outputs = outputs;
68}

◆ run()

void onert::backend::cpu::ops::SplitLayer::run ( )
override virtual

Implements onert::exec::IFunction.

Definition at line 70 of file SplitLayer.cc.

71{
72 if (_input->data_type() == OperandType::FLOAT32)
73 {
74 split<float>();
75 }
76 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
77 {
78 split<uint8_t>();
79 }
80 else if (_input->data_type() == OperandType::INT32)
81 {
82 split<int32_t>();
83 }
84 else if (_input->data_type() == OperandType::INT64)
85 {
86 split<int64_t>();
87 }
88 else
89 {
90 throw std::runtime_error{"Split: unsupported input type"};
91 }
92}
ir::DataType data_type() const override final

References onert::backend::IPortableTensor::data_type().

◆ split()

template<typename T >
void onert::backend::cpu::ops::SplitLayer::split ( void  )

Definition at line 31 of file SplitLayer.cc.

32{
33 nnfw::cker::SplitParams op_params;
34 if (_axis->total_size() != sizeof(int32_t))
35 {
36 throw std::runtime_error("Split: wrong shape of axis");
37 }
38 auto axis = *getBuffer<int32_t>(_axis);
39 if (axis < 0)
40 {
41 axis += _input->getShape().rank();
42 }
43 op_params.axis = axis;
44 op_params.num_split = _num_splits;
45
46 std::vector<T *> outputPtrs;
47
48 for (const auto output : _outputs)
49 {
50 assert(output->total_size() == sizeOfData(output->data_type(), output->getShape().dims()));
51 outputPtrs.emplace_back(getBuffer<T>(output));
52 }
53
54 assert(_input->total_size() == sizeOfData(_input->data_type(), _input->getShape().dims()));
55 nnfw::cker::Split<T>(op_params, getShape(_input), getBuffer<T>(_input), getShape(_outputs[0]),
56 outputPtrs.data());
57}
size_t total_size() const override final
ir::Shape getShape() const override final
Get ir::Shape of tensor.
nnfw::cker::Shape getShape(const IPortableTensor *tensor)
uint32_t sizeOfData(OperandType type, const std::vector< int32_t > &dimensions)

References nnfw::cker::SplitParams::axis, onert::backend::IPortableTensor::data_type(), onert::backend::IPortableTensor::getShape(), onert::backend::cpu::ops::getShape(), nnfw::cker::SplitParams::num_split, onert::backend::cpu::ops::sizeOfData(), and onert::backend::IPortableTensor::total_size().


The documentation for this class was generated from the following files: