ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::cpu::ops::SplitVLayer Class Reference

#include <SplitVLayer.h>

Collaboration diagram for onert::backend::cpu::ops::SplitVLayer:

Public Member Functions

 SplitVLayer ()
 
template<typename T >
void splitV (void)
 
void configure (const IPortableTensor *input, const IPortableTensor *size_splits, const IPortableTensor *size_dim, uint16_t num_splits, std::vector< IPortableTensor * > &outputs)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

Definition at line 33 of file SplitVLayer.h.

Constructor & Destructor Documentation

◆ SplitVLayer()

onert::backend::cpu::ops::SplitVLayer::SplitVLayer ( )

Definition at line 32 of file SplitVLayer.cc.

33 : _input(nullptr), _size_splits(nullptr), _split_dim(nullptr), _num_splits(0), _outputs()
34{
35 // DO NOTHING
36}

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::SplitVLayer::configure ( const IPortableTensor *  input,
const IPortableTensor *  size_splits,
const IPortableTensor *  size_dim,
uint16_t  num_splits,
std::vector< IPortableTensor * > &  outputs 
)

Definition at line 59 of file SplitVLayer.cc.

62{
63 assert(input != nullptr);
64
65 _num_splits = num_splits;
66 _size_splits = size_splits;
67 _input = input;
68 _split_dim = split_dim;
69 _outputs = outputs;
70}

◆ run()

void onert::backend::cpu::ops::SplitVLayer::run ( )
overridevirtual

Implements onert::exec::IFunction.

Definition at line 72 of file SplitVLayer.cc.

73{
74 if (_input->data_type() == OperandType::FLOAT32)
75 {
76 splitV<float>();
77 }
78 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
79 {
80 splitV<uint8_t>();
81 }
82 else if (_input->data_type() == OperandType::INT32)
83 {
84 splitV<int32_t>();
85 }
86 else if (_input->data_type() == OperandType::INT64)
87 {
88 splitV<int64_t>();
89 }
90 else
91 {
92 throw std::runtime_error{"SplitV: unsupported input type"};
93 }
94}
ir::DataType data_type() const override final

References onert::backend::IPortableTensor::data_type().

Referenced by package.infer.session::inference().

◆ splitV()

template<typename T >
void onert::backend::cpu::ops::SplitVLayer::splitV ( void  )

Definition at line 38 of file SplitVLayer.cc.

39{
40 nnfw::cker::SplitVParams op_params;
41 op_params.axis = *getBuffer<int32_t>(_split_dim);
42 op_params.num_split = _num_splits;
43
44 std::vector<T *> outputPtrs;
45 std::vector<nnfw::cker::Shape> outshape;
46
47 for (const auto output : _outputs)
48 {
49 assert(output->total_size() == sizeOfData(output->data_type(), output->getShape().dims()));
50 outputPtrs.emplace_back(getBuffer<T>(output));
51 outshape.emplace_back(getShape(output));
52 }
53
54 assert(_input->total_size() == sizeOfData(_input->data_type(), _input->getShape().dims()));
55 nnfw::cker::SplitV<T>(op_params, getShape(_input), getBuffer<T>(_input), outshape,
56 outputPtrs.data());
57}
size_t total_size() const override final
ir::Shape getShape() const override final
Get ir::Shape of tensor.
nnfw::cker::Shape getShape(const IPortableTensor *tensor)
uint32_t sizeOfData(OperandType type, const std::vector< int32_t > &dimensions)

References nnfw::cker::SplitVParams::axis, onert::backend::IPortableTensor::data_type(), onert::backend::IPortableTensor::getShape(), onert::backend::cpu::ops::getShape(), nnfw::cker::SplitVParams::num_split, onert::backend::cpu::ops::sizeOfData(), and onert::backend::IPortableTensor::total_size().


The documentation for this class was generated from the following files: