ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Slice Class Reference

#include <Slice.h>

Collaboration diagram for luci_interpreter::kernels::Slice:

Public Member Functions

 Slice (const Tensor *input, const Tensor *begin, const Tensor *size, Tensor *output)
 
const Tensor * input () const
 
const Tensor * begin () const
 
const Tensor * size () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 27 of file Slice.h.

Constructor & Destructor Documentation

◆ Slice()

luci_interpreter::kernels::Slice::Slice ( const Tensor * input,
const Tensor * begin,
const Tensor * size,
Tensor * output 
)

Definition at line 31 of file Slice.cpp.

32 : Kernel({input, begin, size}, {output})
33{
34}
Kernel(std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
Definition Kernel.h:31
const Tensor * begin() const
Definition Slice.h:33
Tensor * output() const
Definition Slice.h:35
const Tensor * input() const
Definition Slice.h:32
const Tensor * size() const
Definition Slice.h:34

References begin(), input(), and size().

Member Function Documentation

◆ begin()

const Tensor * luci_interpreter::kernels::Slice::begin ( ) const
inline

Definition at line 33 of file Slice.h.

33{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Slice().

◆ configure()

void luci_interpreter::kernels::Slice::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 74 of file Slice.cpp.

75{
76 assert(input()->element_type() == output()->element_type());
77 assert(begin()->element_type() == DataType::S32 || begin()->element_type() == DataType::S64);
78 assert(size()->element_type() == DataType::S32 || size()->element_type() == DataType::S64);
79 assert(begin()->shape().num_dims() == 1);
80 assert(size()->shape().num_dims() == 1);
81 assert(input()->shape().num_dims() <= max_dim);
82
83 if (begin()->element_type() == DataType::S32)
84 {
85 output()->resize(calculateOutputShape<int32_t>(input(), begin(), size()));
86 }
87 else if (begin()->element_type() == DataType::S64)
88 {
89 output()->resize(calculateOutputShape<int64_t>(input(), begin(), size()));
90 }
91 else
92 {
93 throw std::runtime_error("luci-intp Slice Unsupported type.");
94 }
95}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56

References begin(), input(), luci_interpreter::kernels::max_dim, output(), luci_interpreter::Tensor::resize(), and size().

◆ execute()

void luci_interpreter::kernels::Slice::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 97 of file Slice.cpp.

98{
99 std::vector<int> begins;
100 begins.reserve(max_dim);
101 std::vector<int> sizes;
102 sizes.reserve(max_dim);
103 if (begin()->element_type() == DataType::S32)
104 {
105 getBeginAndSizeVectors<int32_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
106 }
107 else if (begin()->element_type() == DataType::S64)
108 {
109 getBeginAndSizeVectors<int64_t>(input()->shape().num_dims(), begin(), size(), &begins, &sizes);
110 }
111 else
112 {
113 throw std::runtime_error("Unsupported begin type.");
114 }
115 for (int i = input()->shape().num_dims(); i < max_dim; ++i)
116 {
117 begins.push_back(0);
118 sizes.push_back(1);
119 }
120
121 assert(begins.size() == 4);
122 assert(sizes.size() == 4);
123 tflite::SliceParams op_params{};
124 op_params.begin_count = 4;
125 op_params.size_count = 4;
126 for (int i = 0; i < 4; i++)
127 {
128 op_params.begin[i] = begins[3 - i];
129 op_params.size[i] = sizes[3 - i];
130 }
131 switch (input()->element_type())
132 {
133 case DataType::FLOAT32:
134 luci_interpreter_pal::Slice(op_params, getTensorShape(input()), getTensorData<float>(input()),
135 getTensorShape(output()), getTensorData<float>(output()));
136 break;
137 case DataType::U8:
138 luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
139 getTensorData<uint8_t>(input()), getTensorShape(output()),
140 getTensorData<uint8_t>(output()));
141 break;
142 case DataType::S8:
143 luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
144 getTensorData<int8_t>(input()), getTensorShape(output()),
145 getTensorData<int8_t>(output()));
146 break;
147 default:
148 throw std::runtime_error("Unsupported input type.");
149 }
150}
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194

References begin(), luci_interpreter::kernels::getTensorShape(), input(), luci_interpreter::kernels::max_dim, output(), and size().

◆ input()

const Tensor * luci_interpreter::kernels::Slice::input ( ) const
inline

Definition at line 32 of file Slice.h.

32{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Slice().

◆ output()

Tensor * luci_interpreter::kernels::Slice::output ( ) const
inline

Definition at line 35 of file Slice.h.

35{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and execute().

◆ size()

const Tensor * luci_interpreter::kernels::Slice::size ( ) const
inline

Definition at line 34 of file Slice.h.

34{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Slice().


The documentation for this class was generated from the following files: