ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALStridedSlice.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_STRIDED_SLICE_H
19#define LUCI_INTERPRETER_PAL_STRIDED_SLICE_H
20
21#include "Params.h"
22
24{
25
26namespace
27{
// Local stand-in for std::clamp(), which is only available from C++17 on.
// Yields hi when v exceeds hi, otherwise lo when v is below lo, otherwise v
// itself. The upper bound is tested before the lower bound.
inline int clamp(const int v, const int lo, const int hi)
{
  return (hi < v) ? hi : ((v < lo) ? lo : v);
}
37
inline bool loopCondition(int index, int stop, int stride)
{
  // Answers "has iteration along this axis finished?".
  // With a positive stride the walk is done once index has reached or passed
  // stop; for any other stride it is done once index is at or below stop.
  if (stride > 0)
  {
    return index >= stop;
  }
  return index <= stop;
}
43
44// Return the "real" index for the end of iteration along that axis. This is an
45// "end" in the traditional C sense, in that it points to one past the last
46// element. ie. So if you were iterating through all elements of a 1D array of
47// size 4, this function would return 4 as the stop, because it is one past the
48// "real" indices of 0, 1, 2 & 3.
49inline int stopForAxis(const StridedSliceParams &params,
50 const luci_interpreter::RuntimeShape &input_shape, int axis,
51 int start_for_axis)
52{
53 const auto end_mask = params.end_mask;
54 const auto shrink_axis_mask = params.shrink_axis_mask;
55 const auto *stop_indices = params.stop_indices;
56 const auto *strides = params.strides;
57 const int axis_size = input_shape.dims(axis);
58 if (axis_size == 0)
59 {
60 return 0;
61 }
62
63 // Begin with the specified index
64 const bool shrink_axis = shrink_axis_mask & (1 << axis);
65 int stop = stop_indices[axis];
66
67 // When shrinking an axis, the end position does not matter (and can be
68 // incorrect when negative indexing is used, see Issue #19260). Always use
69 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
70 // already been adjusted for negative indices.
71 if (shrink_axis)
72 {
73 return start_for_axis + 1;
74 }
75
76 // end_mask override
77 if (end_mask & (1 << axis))
78 {
79 if (strides[axis] > 0)
80 {
81 // Forward iteration - use the last element. These values will get
82 // clamped below
83 stop = std::numeric_limits<int>::max();
84 }
85 else
86 {
87 // Backward iteration - use the first element.
88 stop = std::numeric_limits<int>::lowest();
89 }
90 }
91
92 // Handle negative indices
93 if (stop < 0)
94 {
95 stop += axis_size;
96 }
97
98 // Clamping
99 // Because the end index points one past the last element, we need slightly
100 // different clamping ranges depending on the direction.
101 if (strides[axis] > 0)
102 {
103 // Forward iteration
104 stop = clamp(stop, 0, axis_size);
105 }
106 else
107 {
108 // Backward iteration
109 stop = clamp(stop, -1, axis_size - 1);
110 }
111
112 return stop;
113}
114
115// Return the index for the first element along that axis. This index will be a
116// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
117// that can be used to index directly into the data.
118inline int startForAxis(const StridedSliceParams &params,
119 const luci_interpreter::RuntimeShape &input_shape, int axis)
120{
121 const auto begin_mask = params.begin_mask;
122 const auto *start_indices = params.start_indices;
123 const auto *strides = params.strides;
124 const int axis_size = input_shape.dims(axis);
125 if (axis_size == 0)
126 {
127 return 0;
128 }
129 // Begin with the specified index.
130 int start = start_indices[axis];
131
132 // begin_mask override
133 if (begin_mask & 1 << axis)
134 {
135 if (strides[axis] > 0)
136 {
137 // Forward iteration - use the first element. These values will get
138 // clamped below (Note: We could have set them to 0 and axis_size-1, but
139 // use lowest() and max() to maintain symmetry with StopForAxis())
140 start = std::numeric_limits<int>::lowest();
141 }
142 else
143 {
144 // Backward iteration - use the last element.
145 start = std::numeric_limits<int>::max();
146 }
147 }
148
149 // Handle negative indices
150 if (start < 0)
151 {
152 start += axis_size;
153 }
154
155 // Clamping
156 if (strides[axis] > 0)
157 {
158 // Forward iteration
159 start = clamp(start, 0, axis_size);
160 }
161 else
162 {
163 // Backward iteration
164 start = clamp(start, -1, axis_size - 1);
165 }
166
167 return start;
168}
169
170inline void stridedSlicePadIndices(StridedSliceParams *p, int dim_count)
171{
172 const int pad_count = dim_count - p->start_indices_count;
173
174 // Pad indices at start, so move arrays by pad_count.
175 for (int i = p->start_indices_count - 1; i >= 0; --i)
176 {
177 p->strides[i + pad_count] = p->strides[i];
178 p->start_indices[i + pad_count] = p->start_indices[i];
179 p->stop_indices[i + pad_count] = p->stop_indices[i];
180 }
181 for (int i = 0; i < pad_count; ++i)
182 {
183 p->start_indices[i] = 0;
184 p->stop_indices[i] = 1;
185 p->strides[i] = 1;
186 }
187
188 // Pad masks with 0s or 1s as required.
189 p->shrink_axis_mask <<= pad_count;
190 p->ellipsis_mask <<= pad_count;
191 p->new_axis_mask <<= pad_count;
192 p->begin_mask <<= pad_count;
193 p->end_mask <<= pad_count;
194 p->begin_mask |= (1 << pad_count) - 1;
195 p->end_mask |= (1 << pad_count) - 1;
196
197 p->start_indices_count = dim_count;
198 p->stop_indices_count = dim_count;
199 p->strides_count = dim_count;
200}
201
202} // namespace
203
204template <typename T>
205inline void StridedSlice(StridedSliceParams &op_params,
206 const luci_interpreter::RuntimeShape &unextended_input_shape,
207 const T *input_data, T *output_data)
208{
209 const luci_interpreter::RuntimeShape input_shape =
210 luci_interpreter::RuntimeShape::extendedShape(5, unextended_input_shape);
211
212 // Reverse and pad to 5 dimensions because that is what the runtime code
213 // requires (ie. all shapes must be 5D and are given backwards).
214 stridedSlicePadIndices(&op_params, 5);
215
216 const int start_0 = startForAxis(op_params, input_shape, 0);
217 const int stop_0 = stopForAxis(op_params, input_shape, 0, start_0);
218 const int start_1 = startForAxis(op_params, input_shape, 1);
219 const int stop_1 = stopForAxis(op_params, input_shape, 1, start_1);
220 const int start_2 = startForAxis(op_params, input_shape, 2);
221 const int stop_2 = stopForAxis(op_params, input_shape, 2, start_2);
222 const int start_3 = startForAxis(op_params, input_shape, 3);
223 const int stop_3 = stopForAxis(op_params, input_shape, 3, start_3);
224 const int start_4 = startForAxis(op_params, input_shape, 4);
225 const int stop_4 = stopForAxis(op_params, input_shape, 4, start_4);
226
227 for (int offset_0 = start_0 * input_shape.dims(1), end_0 = stop_0 * input_shape.dims(1),
228 step_0 = op_params.strides[0] * input_shape.dims(1);
229 !loopCondition(offset_0, end_0, op_params.strides[0]); offset_0 += step_0)
230 {
231 for (int offset_1 = (offset_0 + start_1) * input_shape.dims(2),
232 end_1 = (offset_0 + stop_1) * input_shape.dims(2),
233 step_1 = op_params.strides[1] * input_shape.dims(2);
234 !loopCondition(offset_1, end_1, op_params.strides[1]); offset_1 += step_1)
235 {
236 for (int offset_2 = (offset_1 + start_2) * input_shape.dims(3),
237 end_2 = (offset_1 + stop_2) * input_shape.dims(3),
238 step_2 = op_params.strides[2] * input_shape.dims(3);
239 !loopCondition(offset_2, end_2, op_params.strides[2]); offset_2 += step_2)
240 {
241 for (int offset_3 = (offset_2 + start_3) * input_shape.dims(4),
242 end_3 = (offset_2 + stop_3) * input_shape.dims(4),
243 step_3 = op_params.strides[3] * input_shape.dims(4);
244 !loopCondition(offset_3, end_3, op_params.strides[3]); offset_3 += step_3)
245 {
246 for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
247 !loopCondition(offset_4, end_4, op_params.strides[4]);
248 offset_4 += op_params.strides[4])
249 {
250 *output_data++ = input_data[offset_4];
251 }
252 }
253 }
254 }
255 }
256}
257
258} // namespace luci_interpreter_pal
259
260#endif // LUCI_INTERPRETER_PAL_STRIDED_SLICE_H
int32_t dims(int i) const
Definition Tensor.h:108
static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
Definition Tensor.h:95
void StridedSlice(StridedSliceParams &op_params, const luci_interpreter::RuntimeShape &unextended_input_shape, const T *input_data, T *output_data)
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54