ONE - On-device Neural Engine
Loading...
Searching...
No Matches
StridedSlice.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
19
20#include "kernels/Utils.h"
21
22#include <tensorflow/lite/kernels/internal/reference/strided_slice.h>
23
24#include <stdexcept>
25
26namespace luci_interpreter
27{
28
29namespace kernels
30{
31
32StridedSlice::StridedSlice(const Tensor *input, const Tensor *begin, const Tensor *end,
33 const Tensor *strides, Tensor *output, const StridedSliceParams &params)
34 : KernelWithParams<StridedSliceParams>({input, begin, end, strides}, {output}, params)
35{
36}
37
39{
40 assert(begin()->shape().num_dims() == 1);
41 assert(end()->shape().num_dims() == 1);
42 assert(strides()->shape().num_dims() == 1);
43 assert(input()->element_type() == output()->element_type());
44 assert(begin()->element_type() == DataType::S32);
45 assert(end()->element_type() == DataType::S32);
46 assert(strides()->element_type() == DataType::S32);
47 assert(input()->shape().num_dims() <= 4);
48 if (params().ellipsis_mask != 0)
49 {
50 throw std::runtime_error("ellipsis_mask is not implemented yet.");
51 }
52 if (params().new_axis_mask != 0)
53 {
54 throw std::runtime_error("new_axis_mask is not implemented yet.");
55 }
56 if (input()->element_type() == DataType::U8)
57 {
58 assert(input()->scale() == output()->scale());
59 assert(input()->zero_point() == output()->zero_point());
60 }
61 tflite::StridedSliceParams op_params{};
62 op_params.start_indices_count = input()->shape().num_dims();
63 op_params.stop_indices_count = input()->shape().num_dims();
64 op_params.strides_count = input()->shape().num_dims();
65
66 for (int i = 0; i < input()->shape().num_dims(); i++)
67 {
68 op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
69 op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
70 op_params.strides[i] = getTensorData<int32_t>(strides())[i];
71 }
72 op_params.begin_mask = params().begin_mask;
73 op_params.ellipsis_mask = 0;
74 op_params.end_mask = params().end_mask;
75 op_params.new_axis_mask = 0;
76 op_params.shrink_axis_mask = params().shrink_axis_mask;
77 std::vector<int32_t> output_shape_vector;
78 for (int i = 0; i < input()->shape().num_dims(); i++)
79 {
80 int idx = input()->shape().num_dims() - i - 1;
81 int32_t stride = getTensorData<int32_t>(strides())[idx];
82 assert(stride != 0);
83 int32_t begin = ::tflite::strided_slice::StartForAxis(op_params, getTensorShape(input()), idx);
84 int32_t end =
85 ::tflite::strided_slice::StopForAxis(op_params, getTensorShape(input()), idx, begin);
86
87 const bool shrink_axis = params().shrink_axis_mask & (1 << idx);
88 if (shrink_axis)
89 {
90 end = begin + 1;
91 }
92
93 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
94 dim_shape = dim_shape < 0 ? 0 : dim_shape;
95 if (!shrink_axis)
96 {
97 output_shape_vector.push_back(dim_shape);
98 }
99 }
100 Shape output_shape = Shape(output_shape_vector.size());
101 for (size_t i = 0; i < output_shape_vector.size(); i++)
102 {
103 output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1];
104 }
106}
107
109{
110 tflite::StridedSliceParams op_params{};
111 op_params.start_indices_count = input()->shape().num_dims();
112 op_params.stop_indices_count = input()->shape().num_dims();
113 op_params.strides_count = input()->shape().num_dims();
114
115 for (int i = 0; i < input()->shape().num_dims(); i++)
116 {
117 op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
118 op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
119 op_params.strides[i] = getTensorData<int32_t>(strides())[i];
120 }
121 op_params.begin_mask = params().begin_mask;
122 op_params.ellipsis_mask = 0;
123 op_params.end_mask = params().end_mask;
124 op_params.new_axis_mask = 0;
125 op_params.shrink_axis_mask = params().shrink_axis_mask;
126
127 switch (input()->element_type())
128 {
129 case DataType::FLOAT32:
130 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
131 getTensorData<float>(input()), getTensorShape(output()),
132 getTensorData<float>(output()));
133 break;
134 case DataType::U8:
135 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
136 getTensorData<uint8_t>(input()), getTensorShape(output()),
137 getTensorData<uint8_t>(output()));
138 break;
139 case DataType::S32:
140 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
141 getTensorData<int32_t>(input()), getTensorShape(output()),
142 getTensorData<int32_t>(output()));
143 break;
144 case DataType::S64:
145 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
146 getTensorData<int64_t>(input()), getTensorShape(output()),
147 getTensorData<int64_t>(output()));
148 break;
149 case DataType::BOOL:
150 tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
151 getTensorData<bool>(input()), getTensorShape(output()),
152 getTensorData<bool>(output()));
153 break;
154 default:
155 throw std::runtime_error("luci-intp StridedSlice Unsupported type.");
156 }
157}
158
159} // namespace kernels
160} // namespace luci_interpreter
const StridedSliceParams & params() const
Definition Kernel.h:67
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
StridedSlice(const Tensor *input, const Tensor *begin, const Tensor *end, const Tensor *strides, Tensor *output, const StridedSliceParams &params)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
int32_t begin[5]
Definition Slice.cpp:33