ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALStridedSlice.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef ONERT_MICRO_EXECUTE_PAL_STRIDED_SLICE_H
19#define ONERT_MICRO_EXECUTE_PAL_STRIDED_SLICE_H
20
21namespace onert_micro
22{
23namespace execute
24{
25namespace pal
26{
27
28namespace
29{
// Restrict v to the closed range [lo, hi]. The upper bound is tested first,
// so if lo > hi any v above hi yields hi.
// Hand-rolled so this header does not depend on C++17's std::clamp().
inline int clamp(const int v, const int lo, const int hi)
{
  if (v > hi)
  {
    return hi;
  }
  return (v < lo) ? lo : v;
}
39
// Returns true once `index` has reached or passed `stop` in the direction of
// travel, i.e. when iteration along the axis must end. A positive stride walks
// upward; any other stride (negative, or zero) is treated as walking downward.
inline bool loopCondition(int index, int stop, int stride)
{
  if (stride > 0)
  {
    return index >= stop;
  }
  return index <= stop;
}
45
46// Return the "real" index for the end of iteration along that axis. This is an
47// "end" in the traditional C sense, in that it points to one past the last
48// element. ie. So if you were iterating through all elements of a 1D array of
49// size 4, this function would return 4 as the stop, because it is one past the
50// "real" indices of 0, 1, 2 & 3.
51inline int stopForAxis(const core::StridedSliceParams &params,
52 const core::OMRuntimeShape &input_shape, int axis, int start_for_axis)
53{
54 const auto end_mask = params.end_mask;
55 const auto shrink_axis_mask = params.shrink_axis_mask;
56 const auto *stop_indices = params.stop_indices;
57 const auto *strides = params.strides;
58 const int axis_size = input_shape.dims(axis);
59 if (axis_size == 0)
60 {
61 return 0;
62 }
63
64 // Begin with the specified index
65 const bool shrink_axis = shrink_axis_mask & (1 << axis);
66 int stop = stop_indices[axis];
67
68 // When shrinking an axis, the end position does not matter (and can be
69 // incorrect when negative indexing is used, see Issue #19260). Always use
70 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
71 // already been adjusted for negative indices.
72 if (shrink_axis)
73 {
74 return start_for_axis + 1;
75 }
76
77 // end_mask override
78 if (end_mask & (1 << axis))
79 {
80 if (strides[axis] > 0)
81 {
82 // Forward iteration - use the last element. These values will get
83 // clamped below
84 stop = std::numeric_limits<int>::max();
85 }
86 else
87 {
88 // Backward iteration - use the first element.
89 stop = std::numeric_limits<int>::lowest();
90 }
91 }
92
93 // Handle negative indices
94 if (stop < 0)
95 {
96 stop += axis_size;
97 }
98
99 // Clamping
100 // Because the end index points one past the last element, we need slightly
101 // different clamping ranges depending on the direction.
102 if (strides[axis] > 0)
103 {
104 // Forward iteration
105 stop = clamp(stop, 0, axis_size);
106 }
107 else
108 {
109 // Backward iteration
110 stop = clamp(stop, -1, axis_size - 1);
111 }
112
113 return stop;
114}
115
116// Return the index for the first element along that axis. This index will be a
117// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
118// that can be used to index directly into the data.
119inline int startForAxis(const core::StridedSliceParams &params,
120 const core::OMRuntimeShape &input_shape, int axis)
121{
122 const auto begin_mask = params.begin_mask;
123 const auto *start_indices = params.start_indices;
124 const auto *strides = params.strides;
125 const int axis_size = input_shape.dims(axis);
126 if (axis_size == 0)
127 {
128 return 0;
129 }
130 // Begin with the specified index.
131 int start = start_indices[axis];
132
133 // begin_mask override
134 if (begin_mask & 1 << axis)
135 {
136 if (strides[axis] > 0)
137 {
138 // Forward iteration - use the first element. These values will get
139 // clamped below (Note: We could have set them to 0 and axis_size-1, but
140 // use lowest() and max() to maintain symmetry with StopForAxis())
141 start = std::numeric_limits<int>::lowest();
142 }
143 else
144 {
145 // Backward iteration - use the last element.
146 start = std::numeric_limits<int>::max();
147 }
148 }
149
150 // Handle negative indices
151 if (start < 0)
152 {
153 start += axis_size;
154 }
155
156 // Clamping
157 if (strides[axis] > 0)
158 {
159 // Forward iteration
160 start = clamp(start, 0, axis_size);
161 }
162 else
163 {
164 // Backward iteration
165 start = clamp(start, -1, axis_size - 1);
166 }
167
168 return start;
169}
170
171inline void stridedSlicePadIndices(core::StridedSliceParams *p, int dim_count)
172{
173 const int pad_count = dim_count - p->start_indices_count;
174
175 // Pad indices at start, so move arrays by pad_count.
176 for (int i = p->start_indices_count - 1; i >= 0; --i)
177 {
178 p->strides[i + pad_count] = p->strides[i];
179 p->start_indices[i + pad_count] = p->start_indices[i];
180 p->stop_indices[i + pad_count] = p->stop_indices[i];
181 }
182 for (int i = 0; i < pad_count; ++i)
183 {
184 p->start_indices[i] = 0;
185 p->stop_indices[i] = 1;
186 p->strides[i] = 1;
187 }
188
189 // Pad masks with 0s or 1s as required.
190 p->shrink_axis_mask <<= pad_count;
191 p->ellipsis_mask <<= pad_count;
192 p->new_axis_mask <<= pad_count;
193 p->begin_mask <<= pad_count;
194 p->end_mask <<= pad_count;
195 p->begin_mask |= (1 << pad_count) - 1;
196 p->end_mask |= (1 << pad_count) - 1;
197
198 p->start_indices_count = dim_count;
199 p->stop_indices_count = dim_count;
200 p->strides_count = dim_count;
201}
202
203} // namespace
204
205template <typename T>
207 const core::OMRuntimeShape &unextended_input_shape, const T *input_data,
208 T *output_data)
209{
210 const core::OMRuntimeShape input_shape =
211 core::OMRuntimeShape::extendedShape(5, unextended_input_shape);
212
213 // Reverse and pad to 5 dimensions because that is what the runtime code
214 // requires (ie. all shapes must be 5D and are given backwards).
215 stridedSlicePadIndices(&op_params, 5);
216
217 const int start_0 = startForAxis(op_params, input_shape, 0);
218 const int stop_0 = stopForAxis(op_params, input_shape, 0, start_0);
219 const int start_1 = startForAxis(op_params, input_shape, 1);
220 const int stop_1 = stopForAxis(op_params, input_shape, 1, start_1);
221 const int start_2 = startForAxis(op_params, input_shape, 2);
222 const int stop_2 = stopForAxis(op_params, input_shape, 2, start_2);
223 const int start_3 = startForAxis(op_params, input_shape, 3);
224 const int stop_3 = stopForAxis(op_params, input_shape, 3, start_3);
225 const int start_4 = startForAxis(op_params, input_shape, 4);
226 const int stop_4 = stopForAxis(op_params, input_shape, 4, start_4);
227
228 for (int offset_0 = start_0 * input_shape.dims(1), end_0 = stop_0 * input_shape.dims(1),
229 step_0 = op_params.strides[0] * input_shape.dims(1);
230 !loopCondition(offset_0, end_0, op_params.strides[0]); offset_0 += step_0)
231 {
232 for (int offset_1 = (offset_0 + start_1) * input_shape.dims(2),
233 end_1 = (offset_0 + stop_1) * input_shape.dims(2),
234 step_1 = op_params.strides[1] * input_shape.dims(2);
235 !loopCondition(offset_1, end_1, op_params.strides[1]); offset_1 += step_1)
236 {
237 for (int offset_2 = (offset_1 + start_2) * input_shape.dims(3),
238 end_2 = (offset_1 + stop_2) * input_shape.dims(3),
239 step_2 = op_params.strides[2] * input_shape.dims(3);
240 !loopCondition(offset_2, end_2, op_params.strides[2]); offset_2 += step_2)
241 {
242 for (int offset_3 = (offset_2 + start_3) * input_shape.dims(4),
243 end_3 = (offset_2 + stop_3) * input_shape.dims(4),
244 step_3 = op_params.strides[3] * input_shape.dims(4);
245 !loopCondition(offset_3, end_3, op_params.strides[3]); offset_3 += step_3)
246 {
247 for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
248 !loopCondition(offset_4, end_4, op_params.strides[4]);
249 offset_4 += op_params.strides[4])
250 {
251 *output_data++ = input_data[offset_4];
252 }
253 }
254 }
255 }
256 }
257 return Ok;
258}
259
260} // namespace pal
261} // namespace execute
262} // namespace onert_micro
263
264#endif // ONERT_MICRO_EXECUTE_PAL_STRIDED_SLICE_H
static OMRuntimeShape extendedShape(int new_shape_size, const OMRuntimeShape &shape)
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
OMStatus StridedSlice(core::StridedSliceParams &op_params, const core::OMRuntimeShape &unextended_input_shape, const T *input_data, T *output_data)