ONE - On-device Neural Engine
Loading...
Searching...
No Matches
StridedSlice.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_STRIDEDSLICE_H__
19#define __NNFW_CKER_STRIDEDSLICE_H__
20
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"

#include <cassert>
#include <cmath>
#include <limits>
26
27namespace nnfw
28{
29namespace cker
30{
// Restrict |v| to the closed range [lo, hi]. The upper bound is tested first,
// then the lower one. Callers must pass lo <= hi (checked in debug builds).
// Hand-rolled stand-in for std::clamp() (C++17).
inline int Clamp(const int v, const int lo, const int hi)
{
  assert(lo <= hi);
  return (hi < v) ? hi : ((v < lo) ? lo : v);
}
41
42inline void StridedSlicePadIndices(StridedSliceParams *p, int dim_count)
43{
44 // Add indices and mask bits to fully include extra dimensions
45 assert(dim_count <= 5);
46 assert(dim_count >= p->start_indices_count);
47 assert(p->start_indices_count == p->stop_indices_count);
48 assert(p->stop_indices_count == p->strides_count);
49
50 const int pad_count = dim_count - p->start_indices_count;
51
52 // Pad indices at start, so move arrays by pad_count.
53 for (int i = p->start_indices_count - 1; i >= 0; --i)
54 {
55 p->strides[i + pad_count] = p->strides[i];
56 p->start_indices[i + pad_count] = p->start_indices[i];
57 p->stop_indices[i + pad_count] = p->stop_indices[i];
58 }
59 for (int i = 0; i < pad_count; ++i)
60 {
61 p->start_indices[i] = 0;
62 p->stop_indices[i] = 1;
63 p->strides[i] = 1;
64 }
65
66 // Pad masks with 0s or 1s as required.
67 p->shrink_axis_mask <<= pad_count;
68 p->ellipsis_mask <<= pad_count;
69 p->new_axis_mask <<= pad_count;
70 p->begin_mask <<= pad_count;
71 p->end_mask <<= pad_count;
72 p->begin_mask |= (1 << pad_count) - 1;
73 p->end_mask |= (1 << pad_count) - 1;
74
75 p->start_indices_count = dim_count;
76 p->stop_indices_count = dim_count;
77 p->strides_count = dim_count;
78}
79
80// Return the index for the first element along that axis. This index will be a
81// positive integer between [0, axis_size - 1] that can be used to index
82// directly into the data.
83inline int StartForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis)
84{
85 const auto begin_mask = params.begin_mask;
86 const auto *start_indices = params.start_indices;
87 const auto *strides = params.strides;
88 // Begin with the specified index.
89 int start = start_indices[axis];
90
91 // begin_mask override
92 if (begin_mask & 1 << axis)
93 {
94 if (strides[axis] > 0)
95 {
96 // Forward iteration - use the first element. These values will get
97 // clamped below (Note: We could have set them to 0 and axis_size-1, but
98 // use lowest() and max() to maintain symmetry with StopForAxis())
99 start = std::numeric_limits<int>::lowest();
100 }
101 else
102 {
103 // Backward iteration - use the last element.
104 start = std::numeric_limits<int>::max();
105 }
106 }
107
108 // Handle negative indices
109 int axis_size = input_shape.Dims(axis);
110 if (start < 0)
111 {
112 start += axis_size;
113 }
114
115 // Clamping
116 start = Clamp(start, 0, axis_size - 1);
117
118 return start;
119}
120
121// Return the "real" index for the end of iteration along that axis. This is an
122// "end" in the traditional C sense, in that it points to one past the last
123// element. ie. So if you were iterating through all elements of a 1D array of
124// size 4, this function would return 4 as the stop, because it is one past the
125// "real" indices of 0, 1, 2 & 3.
126inline int StopForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis,
127 int start_for_axis)
128{
129 const auto end_mask = params.end_mask;
130 const auto shrink_axis_mask = params.shrink_axis_mask;
131 const auto *stop_indices = params.stop_indices;
132 const auto *strides = params.strides;
133
134 // Begin with the specified index
135 const bool shrink_axis = shrink_axis_mask & (1 << axis);
136 int stop = stop_indices[axis];
137
138 // When shrinking an axis, the end position does not matter (and can be
139 // incorrect when negative indexing is used, see Issue #19260). Always use
140 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
141 // already been adjusted for negative indices.
142 if (shrink_axis)
143 {
144 stop = start_for_axis + 1;
145 }
146
147 // end_mask override
148 if (end_mask & (1 << axis))
149 {
150 if (strides[axis] > 0)
151 {
152 // Forward iteration - use the last element. These values will get
153 // clamped below
154 stop = std::numeric_limits<int>::max();
155 }
156 else
157 {
158 // Backward iteration - use the first element.
159 stop = std::numeric_limits<int>::lowest();
160 }
161 }
162
163 // Handle negative indices
164 const int axis_size = input_shape.Dims(axis);
165 if (stop < 0)
166 {
167 stop += axis_size;
168 }
169
170 // Clamping
171 // Because the end index points one past the last element, we need slightly
172 // different clamping ranges depending on the direction.
173 if (strides[axis] > 0)
174 {
175 // Forward iteration
176 stop = Clamp(stop, 0, axis_size);
177 }
178 else
179 {
180 // Backward iteration
181 stop = Clamp(stop, -1, axis_size - 1);
182 }
183
184 return stop;
185}
186
// Termination test for one strided axis: true once |index| has reached or
// passed |stop| in the direction of |stride|, meaning the loop must end.
// Any non-positive stride is treated as backward iteration.
inline bool LoopCondition(int index, int stop, int stride)
{
  if (stride > 0)
  {
    return !(index < stop);
  }
  return !(index > stop);
}
192
193template <typename T>
194inline StridedSliceParams
195buildStridedSliceParams(const T *begin, const T *end, const T *strides, const uint32_t begin_mask,
196 const uint32_t end_mask, const uint32_t shrink_axis_mask,
197 const uint8_t rank)
198{
199 StridedSliceParams op_params;
200 op_params.start_indices_count = rank;
201 op_params.stop_indices_count = rank;
202 op_params.strides_count = rank;
203
204 for (int i = 0; i < rank; ++i)
205 {
206 op_params.start_indices[i] = begin[i];
207 op_params.stop_indices[i] = end[i];
208 op_params.strides[i] = strides[i];
209
210 assert(op_params.strides[i] != 0);
211 }
212
213 op_params.begin_mask = begin_mask;
214 op_params.ellipsis_mask = 0; // NYI
215 op_params.end_mask = end_mask;
216 op_params.new_axis_mask = 0; // NYI
217 op_params.shrink_axis_mask = shrink_axis_mask;
218
219 assert(sizeof(op_params.begin_mask) * 4 >= rank);
220
221 return op_params;
222}
223
224void checkOutputSize(const StridedSliceParams &op_params, const Shape &input_shape,
225 [[maybe_unused]] const Shape &output_shape, uint32_t rank)
226{
227 [[maybe_unused]] int32_t shape_size = 0;
228
229 for (uint32_t idx = 0; idx < rank; ++idx)
230 {
231 int32_t stride = op_params.strides[idx];
232 int32_t begin = StartForAxis(op_params, input_shape, idx);
233 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
234
235 // When shrinking an axis, the end position does not matter (and can be
236 // incorrect when negative indexing is used, see Issue #19260). Always use
237 // begin + 1 to generate a length 1 slice, since begin has
238 // already been adjusted for negative indices by StartForAxis.
239 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
240 if (shrink_axis)
241 {
242 end = begin + 1;
243 }
244
245 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
246 dim_shape = dim_shape < 0 ? 0 : dim_shape;
247 if (!shrink_axis)
248 {
249 assert(output_shape.Dims(shape_size) == dim_shape);
250 shape_size++;
251 }
252 }
253
254 assert(output_shape.DimensionsCount() == shape_size);
255}
256
257template <typename T>
258inline void StridedSlice(const StridedSliceParams &op_params, const Shape &unextended_input_shape,
259 const T *input_data, const Shape &unextended_output_shape, T *output_data)
260{
261 assert(unextended_input_shape.DimensionsCount() <= 5);
262 assert(unextended_output_shape.DimensionsCount() <= 5);
263
264 bool optimize = true;
265 int st_count = op_params.strides_count;
266 for (int idx = 0; idx < st_count - 1; idx++)
267 {
268 const int axis_size = unextended_input_shape.Dims(idx);
269 const int start = StartForAxis(op_params, unextended_input_shape, idx);
270 const int stop = StopForAxis(op_params, unextended_input_shape, idx, start);
271 if ((axis_size != 1) && (start != 0 || stop != 0))
272 {
273 optimize = false;
274 break;
275 }
276 }
277
278 if (optimize)
279 {
280 if (op_params.strides[st_count - 1] == 1)
281 {
282 const int start = StartForAxis(op_params, unextended_input_shape, st_count - 1);
283 const int end = StopForAxis(op_params, unextended_input_shape, st_count - 1, start);
284
285 for (int idx = 0; idx < end - start; idx++)
286 {
287 output_data[idx] = input_data[idx + start];
288 }
289 return;
290 }
291 }
292
293 // Note that the output_shape is not used herein.
294 StridedSliceParams params_copy = op_params;
295
296 const Shape input_shape = Shape::ExtendedShape(5, unextended_input_shape);
297 const Shape output_shape = Shape::ExtendedShape(5, unextended_output_shape);
298
299 // Reverse and pad to 5 dimensions because that is what the runtime code
300 // requires (ie. all shapes must be 5D and are given backwards).
301 StridedSlicePadIndices(&params_copy, 5);
302
303 const int start_0 = StartForAxis(params_copy, input_shape, 0);
304 const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
305 const int start_1 = StartForAxis(params_copy, input_shape, 1);
306 const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
307 const int start_2 = StartForAxis(params_copy, input_shape, 2);
308 const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
309 const int start_3 = StartForAxis(params_copy, input_shape, 3);
310 const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
311 const int start_4 = StartForAxis(params_copy, input_shape, 4);
312 const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
313
314 T *out_ptr = output_data;
315 for (int in_0 = start_0; !LoopCondition(in_0, stop_0, params_copy.strides[0]);
316 in_0 += params_copy.strides[0])
317 {
318 for (int in_1 = start_1; !LoopCondition(in_1, stop_1, params_copy.strides[1]);
319 in_1 += params_copy.strides[1])
320 {
321 for (int in_2 = start_2; !LoopCondition(in_2, stop_2, params_copy.strides[2]);
322 in_2 += params_copy.strides[2])
323 {
324 for (int in_3 = start_3; !LoopCondition(in_3, stop_3, params_copy.strides[3]);
325 in_3 += params_copy.strides[3])
326 {
327 for (int in_4 = start_4; !LoopCondition(in_4, stop_4, params_copy.strides[4]);
328 in_4 += params_copy.strides[4])
329 {
330 *out_ptr++ = input_data[Offset(input_shape, in_0, in_1, in_2, in_3, in_4)];
331 }
332 }
333 }
334 }
335 }
336}
337
338} // namespace cker
339} // namespace nnfw
340
341#endif // __NNFW_CKER_STRIDEDSLICE_H__
int32_t DimensionsCount() const
Definition Shape.h:103
int32_t Dims(int i) const
Definition Shape.h:106
const luci_interpreter::RuntimeShape output_shape
bool LoopCondition(int index, int stop, int stride)
int Clamp(const int v, const int lo, const int hi)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:325
ShapeIterator end(const Shape &s)
int StopForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis, int start_for_axis)
void StridedSlice(const StridedSliceParams &op_params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data)
void StridedSlicePadIndices(StridedSliceParams *p, int dim_count)
void checkOutputSize(const StridedSliceParams &op_params, const Shape &input_shape, const Shape &output_shape, uint32_t rank)
StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
int StartForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis)
Definition topk_v2.h:30
int32_t begin[5]
Definition Slice.cpp:33
Configuration p