ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::StridedSlice Class Reference

#include <StridedSlice.h>

Collaboration diagram for luci_interpreter::kernels::StridedSlice:

Public Member Functions

 StridedSlice (const Tensor *input, const Tensor *begin, const Tensor *end, const Tensor *strides, Tensor *output, const StridedSliceParams &params)
 
const Tensor * input () const
 
const Tensor * begin () const
 
const Tensor * end () const
 
const Tensor * strides () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< StridedSliceParams >
const StridedSliceParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< StridedSliceParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const StridedSliceParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< StridedSliceParams >
const StridedSliceParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file StridedSlice.h.

Constructor & Destructor Documentation

◆ StridedSlice()

luci_interpreter::kernels::StridedSlice::StridedSlice ( const Tensor * input,
const Tensor * begin,
const Tensor * end,
const Tensor * strides,
Tensor * output,
const StridedSliceParams & params 
)

Definition at line 32 of file StridedSlice.cpp.

34 : KernelWithParams<StridedSliceParams>({input, begin, end, strides}, {output}, params)
35{
36}
const StridedSliceParams & params() const
Definition Kernel.h:67

References begin(), end(), input(), and strides().

Member Function Documentation

◆ begin()

const Tensor * luci_interpreter::kernels::StridedSlice::begin ( ) const
inline

Definition at line 35 of file StridedSlice.h.

35{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and StridedSlice().

◆ configure()

void luci_interpreter::kernels::StridedSlice::configure ( )
override, virtual

Implements luci_interpreter::Kernel.

Definition at line 38 of file StridedSlice.cpp.

39{
40 assert(begin()->shape().num_dims() == 1);
41 assert(end()->shape().num_dims() == 1);
42 assert(strides()->shape().num_dims() == 1);
43 assert(input()->element_type() == output()->element_type());
44 assert(begin()->element_type() == DataType::S32);
45 assert(end()->element_type() == DataType::S32);
46 assert(strides()->element_type() == DataType::S32);
47 assert(input()->shape().num_dims() <= 4);
48 if (params().ellipsis_mask != 0)
49 {
50 throw std::runtime_error("ellipsis_mask is not implemented yet.");
51 }
52 if (params().new_axis_mask != 0)
53 {
54 throw std::runtime_error("new_axis_mask is not implemented yet.");
55 }
56 if (input()->element_type() == DataType::U8)
57 {
58 assert(input()->scale() == output()->scale());
59 assert(input()->zero_point() == output()->zero_point());
60 }
61 tflite::StridedSliceParams op_params{};
62 op_params.start_indices_count = input()->shape().num_dims();
63 op_params.stop_indices_count = input()->shape().num_dims();
64 op_params.strides_count = input()->shape().num_dims();
65
66 for (int i = 0; i < input()->shape().num_dims(); i++)
67 {
68 op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
69 op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
70 op_params.strides[i] = getTensorData<int32_t>(strides())[i];
71 }
72 op_params.begin_mask = params().begin_mask;
73 op_params.ellipsis_mask = 0;
74 op_params.end_mask = params().end_mask;
75 op_params.new_axis_mask = 0;
76 op_params.shrink_axis_mask = params().shrink_axis_mask;
77 std::vector<int32_t> output_shape_vector;
78 for (int i = 0; i < input()->shape().num_dims(); i++)
79 {
80 int idx = input()->shape().num_dims() - i - 1;
81 int32_t stride = getTensorData<int32_t>(strides())[idx];
82 assert(stride != 0);
83 int32_t begin = ::tflite::strided_slice::StartForAxis(op_params, getTensorShape(input()), idx);
84 int32_t end =
85 ::tflite::strided_slice::StopForAxis(op_params, getTensorShape(input()), idx, begin);
86
87 const bool shrink_axis = params().shrink_axis_mask & (1 << idx);
88 if (shrink_axis)
89 {
90 end = begin + 1;
91 }
92
93 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
94 dim_shape = dim_shape < 0 ? 0 : dim_shape;
95 if (!shrink_axis)
96 {
97 output_shape_vector.push_back(dim_shape);
98 }
99 }
100 Shape output_shape = Shape(output_shape_vector.size());
101 for (size_t i = 0; i < output_shape_vector.size(); i++)
102 {
103 output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1];
104 }
105 output()->resize(output_shape);
106}
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
Definition Shape.h:28

References begin(), luci_interpreter::StridedSliceParams::begin_mask, end(), luci_interpreter::StridedSliceParams::end_mask, luci_interpreter::kernels::getTensorShape(), input(), luci_interpreter::Shape::num_dims(), output(), output_shape, luci_interpreter::KernelWithParams< StridedSliceParams >::params(), luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::shape(), luci_interpreter::StridedSliceParams::shrink_axis_mask, and strides().

◆ end()

const Tensor * luci_interpreter::kernels::StridedSlice::end ( ) const
inline

Definition at line 36 of file StridedSlice.h.

36{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and StridedSlice().

◆ execute()

void luci_interpreter::kernels::StridedSlice::execute ( ) const
override, virtual

Implements luci_interpreter::Kernel.

Definition at line 108 of file StridedSlice.cpp.

109{
110 tflite::StridedSliceParams op_params{};
111 op_params.start_indices_count = input()->shape().num_dims();
112 op_params.stop_indices_count = input()->shape().num_dims();
113 op_params.strides_count = input()->shape().num_dims();
114
115 for (int i = 0; i < input()->shape().num_dims(); i++)
116 {
117 op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
118 op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
119 op_params.strides[i] = getTensorData<int32_t>(strides())[i];
120 }
121 op_params.begin_mask = params().begin_mask;
122 op_params.ellipsis_mask = 0;
123 op_params.end_mask = params().end_mask;
124 op_params.new_axis_mask = 0;
125 op_params.shrink_axis_mask = params().shrink_axis_mask;
126
127 switch (input()->element_type())
128 {
129 case DataType::FLOAT32:
130 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
131 getTensorData<float>(input()), getTensorShape(output()),
132 getTensorData<float>(output()));
133 break;
134 case DataType::U8:
135 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
136 getTensorData<uint8_t>(input()), getTensorShape(output()),
137 getTensorData<uint8_t>(output()));
138 break;
139 case DataType::S32:
140 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
141 getTensorData<int32_t>(input()), getTensorShape(output()),
142 getTensorData<int32_t>(output()));
143 break;
144 case DataType::S64:
145 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
146 getTensorData<int64_t>(input()), getTensorShape(output()),
147 getTensorData<int64_t>(output()));
148 break;
149 case DataType::BOOL:
150 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
151 getTensorData<bool>(input()), getTensorShape(output()),
152 getTensorData<bool>(output()));
153 break;
154 default:
155 throw std::runtime_error("luci-intp StridedSlice Unsupported type.");
156 }
157}

References begin(), luci_interpreter::StridedSliceParams::begin_mask, end(), luci_interpreter::StridedSliceParams::end_mask, luci_interpreter::kernels::getTensorShape(), input(), luci_interpreter::Shape::num_dims(), output(), luci_interpreter::KernelWithParams< StridedSliceParams >::params(), luci_interpreter::Tensor::shape(), luci_interpreter::StridedSliceParams::shrink_axis_mask, and strides().

◆ input()

const Tensor * luci_interpreter::kernels::StridedSlice::input ( ) const
inline

Definition at line 34 of file StridedSlice.h.

34{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and StridedSlice().

◆ output()

Tensor * luci_interpreter::kernels::StridedSlice::output ( ) const
inline

Definition at line 38 of file StridedSlice.h.

38{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and execute().

◆ strides()

const Tensor * luci_interpreter::kernels::StridedSlice::strides ( ) const
inline

Definition at line 37 of file StridedSlice.h.

37{ return _inputs[3]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and StridedSlice().


The documentation for this class was generated from the following files: