ONE - On-device Neural Engine
Loading...
Searching...
No Matches
StridedSlice.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_STRIDEDSLICE_H__
19#define __NNFW_CKER_STRIDEDSLICE_H__
20
21#include "cker/Shape.h"
22#include "cker/Types.h"
23#include "cker/Utils.h"
24
25#include <cmath>
26
27namespace nnfw
28{
29namespace cker
30{
// Local integer stand-in for std::clamp(): limits v to the closed range
// [lo, hi]. The range must be well-formed (lo <= hi); this is checked by an
// assert in debug builds only.
inline int Clamp(const int v, const int lo, const int hi)
{
  assert(lo <= hi);
  return (v > hi) ? hi : ((v < lo) ? lo : v);
}
41
42inline void StridedSlicePadIndices(StridedSliceParams *p, int dim_count)
43{
44 // Add indices and mask bits to fully include extra dimensions
45 assert(dim_count <= 4);
46 assert(dim_count >= p->start_indices_count);
48 assert(p->stop_indices_count == p->strides_count);
49
50 const int pad_count = dim_count - p->start_indices_count;
51
52 // Pad indices at start, so move arrays by pad_count.
53 for (int i = p->start_indices_count - 1; i >= 0; --i)
54 {
55 p->strides[i + pad_count] = p->strides[i];
56 p->start_indices[i + pad_count] = p->start_indices[i];
57 p->stop_indices[i + pad_count] = p->stop_indices[i];
58 }
59 for (int i = 0; i < pad_count; ++i)
60 {
61 p->start_indices[i] = 0;
62 p->stop_indices[i] = 1;
63 p->strides[i] = 1;
64 }
65
66 // Pad masks with 0s or 1s as required.
67 p->shrink_axis_mask <<= pad_count;
68 p->ellipsis_mask <<= pad_count;
69 p->new_axis_mask <<= pad_count;
70 p->begin_mask <<= pad_count;
71 p->end_mask <<= pad_count;
72 p->begin_mask |= (1 << pad_count) - 1;
73 p->end_mask |= (1 << pad_count) - 1;
74
75 p->start_indices_count = dim_count;
76 p->stop_indices_count = dim_count;
77 p->strides_count = dim_count;
78}
79
80// Return the index for the first element along that axis. This index will be a
81// positive integer between [0, axis_size - 1] that can be used to index
82// directly into the data.
83inline int StartForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis)
84{
85 const auto begin_mask = params.begin_mask;
86 const auto *start_indices = params.start_indices;
87 const auto *strides = params.strides;
88 // Begin with the specified index.
89 int start = start_indices[axis];
90
91 // begin_mask override
92 if (begin_mask & 1 << axis)
93 {
94 if (strides[axis] > 0)
95 {
96 // Forward iteration - use the first element. These values will get
97 // clamped below (Note: We could have set them to 0 and axis_size-1, but
98 // use lowest() and max() to maintain symmetry with StopForAxis())
99 start = std::numeric_limits<int>::lowest();
100 }
101 else
102 {
103 // Backward iteration - use the last element.
104 start = std::numeric_limits<int>::max();
105 }
106 }
107
108 // Handle negative indices
109 int axis_size = input_shape.Dims(axis);
110 if (start < 0)
111 {
112 start += axis_size;
113 }
114
115 // Clamping
116 start = Clamp(start, 0, axis_size - 1);
117
118 return start;
119}
120
121// Return the "real" index for the end of iteration along that axis. This is an
122// "end" in the traditional C sense, in that it points to one past the last
123// element. ie. So if you were iterating through all elements of a 1D array of
124// size 4, this function would return 4 as the stop, because it is one past the
125// "real" indices of 0, 1, 2 & 3.
126inline int StopForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis,
127 int start_for_axis)
128{
129 const auto end_mask = params.end_mask;
130 const auto shrink_axis_mask = params.shrink_axis_mask;
131 const auto *stop_indices = params.stop_indices;
132 const auto *strides = params.strides;
133
134 // Begin with the specified index
135 const bool shrink_axis = shrink_axis_mask & (1 << axis);
136 int stop = stop_indices[axis];
137
138 // When shrinking an axis, the end position does not matter (and can be
139 // incorrect when negative indexing is used, see Issue #19260). Always use
140 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
141 // already been adjusted for negative indices.
142 if (shrink_axis)
143 {
144 stop = start_for_axis + 1;
145 }
146
147 // end_mask override
148 if (end_mask & (1 << axis))
149 {
150 if (strides[axis] > 0)
151 {
152 // Forward iteration - use the last element. These values will get
153 // clamped below
154 stop = std::numeric_limits<int>::max();
155 }
156 else
157 {
158 // Backward iteration - use the first element.
159 stop = std::numeric_limits<int>::lowest();
160 }
161 }
162
163 // Handle negative indices
164 const int axis_size = input_shape.Dims(axis);
165 if (stop < 0)
166 {
167 stop += axis_size;
168 }
169
170 // Clamping
171 // Because the end index points one past the last element, we need slightly
172 // different clamping ranges depending on the direction.
173 if (strides[axis] > 0)
174 {
175 // Forward iteration
176 stop = Clamp(stop, 0, axis_size);
177 }
178 else
179 {
180 // Backward iteration
181 stop = Clamp(stop, -1, axis_size - 1);
182 }
183
184 return stop;
185}
186
// Tells whether iteration along an axis is finished: returns true once `index`
// has reached or passed `stop` in the direction of travel. Any non-positive
// stride is treated as backward iteration.
inline bool LoopCondition(int index, int stop, int stride)
{
  if (stride > 0)
  {
    return index >= stop;
  }
  return index <= stop;
}
192
193template <typename T>
194inline StridedSliceParams
195buildStridedSliceParams(const T *begin, const T *end, const T *strides, const uint32_t begin_mask,
196 const uint32_t end_mask, const uint32_t shrink_axis_mask,
197 const uint8_t rank)
198{
199 StridedSliceParams op_params;
200 op_params.start_indices_count = rank;
201 op_params.stop_indices_count = rank;
202 op_params.strides_count = rank;
203
204 for (int i = 0; i < rank; ++i)
205 {
206 op_params.start_indices[i] = begin[i];
207 op_params.stop_indices[i] = end[i];
208 op_params.strides[i] = strides[i];
209
210 assert(op_params.strides[i] != 0);
211 }
212
213 op_params.begin_mask = begin_mask;
214 op_params.ellipsis_mask = 0; // NYI
215 op_params.end_mask = end_mask;
216 op_params.new_axis_mask = 0; // NYI
217 op_params.shrink_axis_mask = shrink_axis_mask;
218
219 assert(sizeof(op_params.begin_mask) * 4 >= rank);
220
221 return op_params;
222}
223
224void checkOutputSize(const StridedSliceParams &op_params, const Shape &input_shape,
225 [[maybe_unused]] const Shape &output_shape, uint32_t rank)
226{
227 [[maybe_unused]] int32_t shape_size = 0;
228
229 for (uint32_t idx = 0; idx < rank; ++idx)
230 {
231 int32_t stride = op_params.strides[idx];
232 int32_t begin = StartForAxis(op_params, input_shape, idx);
233 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
234
235 // When shrinking an axis, the end position does not matter (and can be
236 // incorrect when negative indexing is used, see Issue #19260). Always use
237 // begin + 1 to generate a length 1 slice, since begin has
238 // already been adjusted for negative indices by StartForAxis.
239 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
240 if (shrink_axis)
241 {
242 end = begin + 1;
243 }
244
245 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
246 dim_shape = dim_shape < 0 ? 0 : dim_shape;
247 if (!shrink_axis)
248 {
249 assert(output_shape.Dims(shape_size) == dim_shape);
250 shape_size++;
251 }
252 }
253
254 assert(output_shape.DimensionsCount() == shape_size);
255}
256
257template <typename T>
258inline void StridedSlice(const StridedSliceParams &op_params, const Shape &unextended_input_shape,
259 const T *input_data, const Shape &unextended_output_shape, T *output_data)
260{
261 assert(unextended_input_shape.DimensionsCount() <= 4);
262 assert(unextended_output_shape.DimensionsCount() <= 4);
263
264 bool optimize = true;
265 int st_count = op_params.strides_count;
266 for (int idx = 0; idx < st_count - 1; idx++)
267 {
268 const int axis_size = unextended_input_shape.Dims(idx);
269 const int start = StartForAxis(op_params, unextended_input_shape, idx);
270 const int stop = StopForAxis(op_params, unextended_input_shape, idx, start);
271 if ((axis_size != 1) && (start != 0 || stop != 0))
272 {
273 optimize = false;
274 break;
275 }
276 }
277
278 if (optimize)
279 {
280 if (op_params.strides[st_count - 1] == 1)
281 {
282 const int start = StartForAxis(op_params, unextended_input_shape, st_count - 1);
283 const int end = StopForAxis(op_params, unextended_input_shape, st_count - 1, start);
284
285 for (int idx = 0; idx < end - start; idx++)
286 {
287 output_data[idx] = input_data[idx + start];
288 }
289 return;
290 }
291 }
292
293 // Note that the output_shape is not used herein.
294 StridedSliceParams params_copy = op_params;
295
296 const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
297 const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
298
299 // Reverse and pad to 4 dimensions because that is what the runtime code
300 // requires (ie. all shapes must be 4D and are given backwards).
301 StridedSlicePadIndices(&params_copy, 4);
302
303 const int start_b = StartForAxis(params_copy, input_shape, 0);
304 const int stop_b = StopForAxis(params_copy, input_shape, 0, start_b);
305 const int start_h = StartForAxis(params_copy, input_shape, 1);
306 const int stop_h = StopForAxis(params_copy, input_shape, 1, start_h);
307 const int start_w = StartForAxis(params_copy, input_shape, 2);
308 const int stop_w = StopForAxis(params_copy, input_shape, 2, start_w);
309 const int start_d = StartForAxis(params_copy, input_shape, 3);
310 const int stop_d = StopForAxis(params_copy, input_shape, 3, start_d);
311
312 T *out_ptr = output_data;
313 for (int in_b = start_b; !LoopCondition(in_b, stop_b, params_copy.strides[0]);
314 in_b += params_copy.strides[0])
315 {
316 for (int in_h = start_h; !LoopCondition(in_h, stop_h, params_copy.strides[1]);
317 in_h += params_copy.strides[1])
318 {
319 for (int in_w = start_w; !LoopCondition(in_w, stop_w, params_copy.strides[2]);
320 in_w += params_copy.strides[2])
321 {
322 for (int in_d = start_d; !LoopCondition(in_d, stop_d, params_copy.strides[3]);
323 in_d += params_copy.strides[3])
324 {
325 *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
326 }
327 }
328 }
329 }
330}
331
332} // namespace cker
333} // namespace nnfw
334
335#endif // __NNFW_CKER_STRIDEDSLICE_H__
int32_t DimensionsCount() const
Definition Shape.h:91
int32_t Dims(int i) const
Definition Shape.h:92
const luci_interpreter::RuntimeShape output_shape
bool LoopCondition(int index, int stop, int stride)
int Clamp(const int v, const int lo, const int hi)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:237
ShapeIterator end(const Shape &s)
int StopForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis, int start_for_axis)
void StridedSlice(const StridedSliceParams &op_params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data)
void StridedSlicePadIndices(StridedSliceParams *p, int dim_count)
void checkOutputSize(const StridedSliceParams &op_params, const Shape &input_shape, const Shape &output_shape, uint32_t rank)
StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
int StartForAxis(const StridedSliceParams &params, const Shape &input_shape, int axis)
Definition topk_v2.h:30
int32_t begin[5]
Definition Slice.cpp:33