129 const auto end_mask = params.
end_mask;
132 const auto *strides = params.
strides;
135 const bool shrink_axis = shrink_axis_mask & (1 << axis);
136 int stop = stop_indices[axis];
144 stop = start_for_axis + 1;
148 if (end_mask & (1 << axis))
150 if (strides[axis] > 0)
154 stop = std::numeric_limits<int>::max();
159 stop = std::numeric_limits<int>::lowest();
164 const int axis_size = input_shape.
Dims(axis);
173 if (strides[axis] > 0)
176 stop =
Clamp(stop, 0, axis_size);
181 stop =
Clamp(stop, -1, axis_size - 1);
227 [[maybe_unused]] int32_t shape_size = 0;
229 for (uint32_t idx = 0; idx < rank; ++idx)
231 int32_t stride = op_params.
strides[idx];
245 int32_t dim_shape = std::ceil((
end -
begin) /
static_cast<float>(stride));
246 dim_shape = dim_shape < 0 ? 0 : dim_shape;
259 const T *input_data,
const Shape &unextended_output_shape, T *output_data)
264 bool optimize =
true;
266 for (
int idx = 0; idx < st_count - 1; idx++)
268 const int axis_size = unextended_input_shape.
Dims(idx);
269 const int start =
StartForAxis(op_params, unextended_input_shape, idx);
270 const int stop =
StopForAxis(op_params, unextended_input_shape, idx, start);
271 if ((axis_size != 1) && (start != 0 || stop != 0))
280 if (op_params.
strides[st_count - 1] == 1)
282 const int start =
StartForAxis(op_params, unextended_input_shape, st_count - 1);
283 const int end =
StopForAxis(op_params, unextended_input_shape, st_count - 1, start);
285 for (
int idx = 0; idx <
end - start; idx++)
287 output_data[idx] = input_data[idx + start];
296 const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
303 const int start_b =
StartForAxis(params_copy, input_shape, 0);
304 const int stop_b =
StopForAxis(params_copy, input_shape, 0, start_b);
305 const int start_h =
StartForAxis(params_copy, input_shape, 1);
306 const int stop_h =
StopForAxis(params_copy, input_shape, 1, start_h);
307 const int start_w =
StartForAxis(params_copy, input_shape, 2);
308 const int stop_w =
StopForAxis(params_copy, input_shape, 2, start_w);
309 const int start_d =
StartForAxis(params_copy, input_shape, 3);
310 const int stop_d =
StopForAxis(params_copy, input_shape, 3, start_d);
312 T *out_ptr = output_data;
314 in_b += params_copy.
strides[0])
317 in_h += params_copy.
strides[1])
320 in_w += params_copy.
strides[2])
323 in_d += params_copy.
strides[3])
325 *out_ptr++ = input_data[
Offset(input_shape, in_b, in_h, in_w, in_d)];
void StridedSlice(const StridedSliceParams &op_params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data)
StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)