18#ifndef ONERT_MICRO_EXECUTE_PAL_STRIDED_SLICE_H
19#define ONERT_MICRO_EXECUTE_PAL_STRIDED_SLICE_H
// Helper called by stopForAxis/startForAxis as clamp(value, lo, hi).
// NOTE(review): the body is not visible in this excerpt (embedded numbering
// jumps from 31 to 40). Presumably it clamps v into [lo, hi] -- TODO confirm.
// If the project can rely on C++17, std::clamp from <algorithm> could replace it.
31inline int clamp(
const int v,
const int lo,
const int hi)
/// Reports whether iteration along one axis has gone past its stop position.
///
/// The strided-slice loops keep going while this returns false, so `stop` acts
/// as an exclusive bound in the direction of travel.
///
/// @param index  current (possibly scaled) position along the axis
/// @param stop   exclusive end position for the axis
/// @param stride step; a positive value walks forward, anything else uses the
///               backward test
/// @return true once `index` has reached or passed `stop` in that direction
inline bool loopCondition(int index, int stop, int stride)
{
  if (stride > 0)
  {
    // Forward: finished as soon as we are no longer below the stop.
    return !(index < stop);
  }
  // Backward (or zero stride): finished as soon as we are no longer above the stop.
  return !(index > stop);
}
// Computes the stop (exclusive end) position along `axis` for a strided slice.
// @param params          slice parameters (end/shrink masks, stop indices, strides)
// @param input_shape     shape whose dims(axis) bounds the result
// @param axis            axis being sliced
// @param start_for_axis  start position already computed for this axis
// @return stop position; the kernel's loops run while !loopCondition(index, stop, stride)
// NOTE(review): the embedded line numbers jump (58 -> 65 -> 74 -> 78 ...), so
// braces, else-branches and some statements are not visible in this excerpt;
// the comments below describe only the visible lines.
// NOTE(review): `¶ms` looks like a mis-encoded `&params` (the body reads
// `params.`) -- restore the `&`.
51inline int stopForAxis(
const core::StridedSliceParams ¶ms,
52 const core::OMRuntimeShape &input_shape,
int axis,
int start_for_axis)
54 const auto end_mask = params.end_mask;
55 const auto shrink_axis_mask = params.shrink_axis_mask;
56 const auto *stop_indices = params.stop_indices;
57 const auto *strides = params.strides;
58 const int axis_size = input_shape.dims(axis);
// Bit `axis` of shrink_axis_mask marks an axis that collapses to one element.
65 const bool shrink_axis = shrink_axis_mask & (1 << axis);
66 int stop = stop_indices[axis];
// Shrunk axis: produce a length-1 slice starting at start_for_axis. The
// guarding `if` is not visible here; presumably `if (shrink_axis)` -- confirm.
74 return start_for_axis + 1;
// end_mask bit set: discard stop_indices[axis] and run to the extreme in the
// stride direction; the extreme is presumably narrowed by the clamp below.
78 if (end_mask & (1 << axis))
80 if (strides[axis] > 0)
84 stop = std::numeric_limits<int>::max();
// Presumably the else-branch (non-positive stride) -- confirm.
89 stop = std::numeric_limits<int>::lowest();
// stop is exclusive (one step beyond the last visited index in the stride
// direction), so its valid range depends on direction:
// [0, axis_size] forward, [-1, axis_size - 1] backward.
102 if (strides[axis] > 0)
105 stop = clamp(stop, 0, axis_size);
// Presumably the else-branch (non-positive stride) -- confirm.
110 stop = clamp(stop, -1, axis_size - 1);
// Computes the start position along `axis` for a strided slice.
// @param params       slice parameters (begin mask, start indices, strides)
// @param input_shape  shape whose dims(axis) bounds the result
// @param axis         axis being sliced
// @return start position, bounded by the clamp calls below (range depends on
//         the stride sign)
// NOTE(review): the embedded line numbers jump (125 -> 131 -> 134 ...), so
// braces, else-branches and some statements are not visible in this excerpt;
// the comments below describe only the visible lines.
// NOTE(review): `¶ms` looks like a mis-encoded `&params` (the body reads
// `params.`) -- restore the `&`.
119inline int startForAxis(
const core::StridedSliceParams ¶ms,
120 const core::OMRuntimeShape &input_shape,
int axis)
122 const auto begin_mask = params.begin_mask;
123 const auto *start_indices = params.start_indices;
124 const auto *strides = params.strides;
125 const int axis_size = input_shape.dims(axis);
131 int start = start_indices[axis];
// `<<` binds tighter than `&`, so this tests bit `axis`, i.e. the same as
// `begin_mask & (1 << axis)`. Parenthesizing it would match the end_mask check
// in stopForAxis.
134 if (begin_mask & 1 << axis)
// begin_mask bit set: discard start_indices[axis] and begin at the extreme for
// the stride direction (lowest for forward, max otherwise); the extreme is
// presumably narrowed by the clamp below.
136 if (strides[axis] > 0)
141 start = std::numeric_limits<int>::lowest();
// Presumably the else-branch (non-positive stride) -- confirm.
146 start = std::numeric_limits<int>::max();
// Bound the start to the axis: [0, axis_size] forward, [-1, axis_size - 1] backward.
157 if (strides[axis] > 0)
160 start = clamp(start, 0, axis_size);
// Presumably the else-branch (non-positive stride) -- confirm.
165 start = clamp(start, -1, axis_size - 1);
// Left-pads the slice parameters in *p to `dim_count` axes, so the kernel can
// always iterate a fixed number of dimensions (the kernel calls it with 5).
// The existing entries move to the last start_indices_count positions. Each
// padded leading axis gets start 0 / stop 1 and has its begin_mask/end_mask bits set.
// NOTE(review): all three arrays are moved using start_indices_count only, so
// stop_indices_count and strides_count are assumed to equal it. The arrays must
// hold at least dim_count entries, since writes reach index dim_count - 1.
// Original lines 174-175 and 186-189 are not visible. Presumably they contain a
// guard for pad_count < 0 (a negative shift count below would be UB) and set the
// stride for the padded axes -- confirm.
171inline void stridedSlicePadIndices(core::StridedSliceParams *p,
int dim_count)
173 const int pad_count = dim_count - p->start_indices_count;
// Walk backward so each entry is copied before the right-shift overwrites it.
176 for (
int i = p->start_indices_count - 1; i >= 0; --i)
178 p->strides[i + pad_count] = p->strides[i];
179 p->start_indices[i + pad_count] = p->start_indices[i];
180 p->stop_indices[i + pad_count] = p->stop_indices[i];
// Padded axes select exactly one element: start 0, stop 1.
182 for (
int i = 0; i < pad_count; ++i)
184 p->start_indices[i] = 0;
185 p->stop_indices[i] = 1;
// Per-axis bit masks move with their axes...
190 p->shrink_axis_mask <<= pad_count;
191 p->ellipsis_mask <<= pad_count;
192 p->new_axis_mask <<= pad_count;
193 p->begin_mask <<= pad_count;
194 p->end_mask <<= pad_count;
// ...and the low pad_count bits are set, so startForAxis/stopForAxis treat the
// padded axes as begin/end-masked.
195 p->begin_mask |= (1 << pad_count) - 1;
196 p->end_mask |= (1 << pad_count) - 1;
198 p->start_indices_count = dim_count;
199 p->stop_indices_count = dim_count;
200 p->strides_count = dim_count;
// Normalize op_params to exactly 5 axes.
// NOTE(review): this writes through op_params, so the caller's parameters are
// modified in place.
215 stridedSlicePadIndices(&op_params, 5);
// Per-axis start/stop bounds. stopForAxis receives the axis' start because a
// shrunk axis ends at start + 1.
217 const int start_0 = startForAxis(op_params, input_shape, 0);
218 const int stop_0 = stopForAxis(op_params, input_shape, 0, start_0);
219 const int start_1 = startForAxis(op_params, input_shape, 1);
220 const int stop_1 = stopForAxis(op_params, input_shape, 1, start_1);
221 const int start_2 = startForAxis(op_params, input_shape, 2);
222 const int stop_2 = stopForAxis(op_params, input_shape, 2, start_2);
223 const int start_3 = startForAxis(op_params, input_shape, 3);
224 const int stop_3 = stopForAxis(op_params, input_shape, 3, start_3);
225 const int start_4 = startForAxis(op_params, input_shape, 4);
226 const int stop_4 = stopForAxis(op_params, input_shape, 4, start_4);
// Walk the 5-axis slice with nested loops. Each level builds a flattened offset
// Horner-style: offset_0 = i0 * dims(1), then offset_k = (offset_{k-1} + i_k) * dims(k + 1).
// The end bounds and steps are scaled by the same factor, so loopCondition
// compares like with like. input_shape is presumably the 5-axis extended form
// of unextended_input_shape -- confirm.
228 for (
int offset_0 = start_0 * input_shape.
dims(1), end_0 = stop_0 * input_shape.
dims(1),
229 step_0 = op_params.
strides[0] * input_shape.
dims(1);
230 !loopCondition(offset_0, end_0, op_params.
strides[0]); offset_0 += step_0)
232 for (
int offset_1 = (offset_0 + start_1) * input_shape.
dims(2),
233 end_1 = (offset_0 + stop_1) * input_shape.
dims(2),
234 step_1 = op_params.
strides[1] * input_shape.
dims(2);
235 !loopCondition(offset_1, end_1, op_params.
strides[1]); offset_1 += step_1)
237 for (
int offset_2 = (offset_1 + start_2) * input_shape.
dims(3),
238 end_2 = (offset_1 + stop_2) * input_shape.
dims(3),
239 step_2 = op_params.
strides[2] * input_shape.
dims(3);
240 !loopCondition(offset_2, end_2, op_params.
strides[2]); offset_2 += step_2)
242 for (
int offset_3 = (offset_2 + start_3) * input_shape.
dims(4),
243 end_3 = (offset_2 + stop_3) * input_shape.
dims(4),
244 step_3 = op_params.
strides[3] * input_shape.
dims(4);
245 !loopCondition(offset_3, end_3, op_params.
strides[3]); offset_3 += step_3)
// Innermost axis: no dims() scaling, so offset_4 is the flat element index.
247 for (
int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
248 !loopCondition(offset_4, end_4, op_params.
strides[4]);
249 offset_4 += op_params.
strides[4])
// Output is written densely in iteration order.
// NOTE(review): no capacity check on output_data is visible here.
251 *output_data++ = input_data[offset_4];
// NOTE(review): declaration only; the body is not visible in this excerpt.
// Presumably it returns `shape` extended to `new_shape_size` dimensions (the
// kernel above indexes dims(0..4) of input_shape) -- TODO confirm.
static OMRuntimeShape extendedShape(int new_shape_size, const OMRuntimeShape &shape)
// Accessor used by the kernel as input_shape.dims(axis). The body is not visible;
// presumably it returns the size of dimension i. Bounds checking is unknown -- confirm.
int32_t dims(int i) const
// NOTE(review): declaration only. As far as this excerpt shows, it is unrelated
// to the strided-slice kernel; the body is not visible.
loco::GraphInputIndex index(const TFPlaceholder *node)
// Strided-slice entry point (template parameter T; the template header is not
// visible). The kernel body in this excerpt that uses op_params/input_data/output_data
// appears to belong to this signature. It copies selected input_data elements into
// output_data. op_params is taken by non-const reference because that body pads
// it in place via stridedSlicePadIndices. The OMStatus values returned are not
// visible here.
OMStatus StridedSlice(core::StridedSliceParams &op_params, const core::OMRuntimeShape &unextended_input_shape, const T *input_data, T *output_data)