45 assert(dim_count <= 5);
46 assert(dim_count >=
p->start_indices_count);
47 assert(
p->start_indices_count ==
p->stop_indices_count);
48 assert(
p->stop_indices_count ==
p->strides_count);
50 const int pad_count = dim_count -
p->start_indices_count;
53 for (
int i =
p->start_indices_count - 1; i >= 0; --i)
55 p->strides[i + pad_count] =
p->strides[i];
56 p->start_indices[i + pad_count] =
p->start_indices[i];
57 p->stop_indices[i + pad_count] =
p->stop_indices[i];
59 for (
int i = 0; i < pad_count; ++i)
61 p->start_indices[i] = 0;
62 p->stop_indices[i] = 1;
67 p->shrink_axis_mask <<= pad_count;
68 p->ellipsis_mask <<= pad_count;
69 p->new_axis_mask <<= pad_count;
70 p->begin_mask <<= pad_count;
71 p->end_mask <<= pad_count;
72 p->begin_mask |= (1 << pad_count) - 1;
73 p->end_mask |= (1 << pad_count) - 1;
75 p->start_indices_count = dim_count;
76 p->stop_indices_count = dim_count;
77 p->strides_count = dim_count;
129 const auto end_mask = params.
end_mask;
132 const auto *strides = params.
strides;
135 const bool shrink_axis = shrink_axis_mask & (1 << axis);
136 int stop = stop_indices[axis];
144 stop = start_for_axis + 1;
148 if (end_mask & (1 << axis))
150 if (strides[axis] > 0)
154 stop = std::numeric_limits<int>::max();
159 stop = std::numeric_limits<int>::lowest();
164 const int axis_size = input_shape.
Dims(axis);
173 if (strides[axis] > 0)
176 stop =
Clamp(stop, 0, axis_size);
181 stop =
Clamp(stop, -1, axis_size - 1);
227 [[maybe_unused]] int32_t shape_size = 0;
229 for (uint32_t idx = 0; idx < rank; ++idx)
231 int32_t stride = op_params.
strides[idx];
245 int32_t dim_shape = std::ceil((
end -
begin) /
static_cast<float>(stride));
246 dim_shape = dim_shape < 0 ? 0 : dim_shape;
259 const T *input_data,
const Shape &unextended_output_shape, T *output_data)
264 bool optimize =
true;
266 for (
int idx = 0; idx < st_count - 1; idx++)
268 const int axis_size = unextended_input_shape.
Dims(idx);
269 const int start =
StartForAxis(op_params, unextended_input_shape, idx);
270 const int stop =
StopForAxis(op_params, unextended_input_shape, idx, start);
271 if ((axis_size != 1) && (start != 0 || stop != 0))
280 if (op_params.
strides[st_count - 1] == 1)
282 const int start =
StartForAxis(op_params, unextended_input_shape, st_count - 1);
283 const int end =
StopForAxis(op_params, unextended_input_shape, st_count - 1, start);
285 for (
int idx = 0; idx <
end - start; idx++)
287 output_data[idx] = input_data[idx + start];
296 const Shape input_shape = Shape::ExtendedShape(5, unextended_input_shape);
303 const int start_0 =
StartForAxis(params_copy, input_shape, 0);
304 const int stop_0 =
StopForAxis(params_copy, input_shape, 0, start_0);
305 const int start_1 =
StartForAxis(params_copy, input_shape, 1);
306 const int stop_1 =
StopForAxis(params_copy, input_shape, 1, start_1);
307 const int start_2 =
StartForAxis(params_copy, input_shape, 2);
308 const int stop_2 =
StopForAxis(params_copy, input_shape, 2, start_2);
309 const int start_3 =
StartForAxis(params_copy, input_shape, 3);
310 const int stop_3 =
StopForAxis(params_copy, input_shape, 3, start_3);
311 const int start_4 =
StartForAxis(params_copy, input_shape, 4);
312 const int stop_4 =
StopForAxis(params_copy, input_shape, 4, start_4);
314 T *out_ptr = output_data;
316 in_0 += params_copy.
strides[0])
319 in_1 += params_copy.
strides[1])
322 in_2 += params_copy.
strides[2])
325 in_3 += params_copy.
strides[3])
328 in_4 += params_copy.
strides[4])
330 *out_ptr++ = input_data[
Offset(input_shape, in_0, in_1, in_2, in_3, in_4)];
StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)