ONE - On-device Neural Engine
onert::shape_inference Namespace Reference

Namespaces

namespace  bcq
 

Data Structures

struct  StridedSliceParams
 

Typedefs

using Shapes = std::vector< ir::Shape >
 

Functions

ir::Shape inferArgMinMaxShape (const ir::Shape &input_shape, int axis, int rank)
 
ir::Shape inferBatchMatMulShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape, const ir::operation::BatchMatMul::Param &param)
 
ir::Shape inferBCQFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf)
 
ir::Shape inferBCQGatherShape (const ir::Shape &indices_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf, int rank, const ir::operation::BCQGather::Param &param)
 
ir::Shape inferBroadcastToShape (const ir::Shape shp_shape, const int32_t *shp_buf)
 
ir::Shape inferConcatShape (const Shapes &in_shapes, const ir::operation::Concat::Param &param)
 
ir::Shape inferConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::Conv2D::Param &param)
 
ir::Shape inferDepthwiseConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::DepthwiseConv2D::Param &param)
 
ir::Shape inferEltwiseShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
 
ir::Shape inferExpandDimsShape (const ir::Shape &in_shape, int32_t axis)
 
template<typename T >
ir::Shape inferFillShape (const ir::Shape &fill_shape, const T *shape_buf)
 
ir::Shape inferFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &ker_shape)
 
ir::Shape inferGatherShape (const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis, int rank)
 
ir::Shape inferOnehotShape (const ir::Shape &input_shape, const int depth, int axis)
 
ir::Shape inferPackShape (const ir::Shape &input_shape, int axis, int rank, int num)
 
ir::Shape inferPadShape (const ir::Shape &in_shape, const int32_t *pad_buf, const size_t num_pads)
 
ir::Shape inferPoolShape (const ir::Shape &in_shape, const ir::operation::Pool2D::Param &param)
 
template<typename T >
ir::Shape inferRangeShape (T start_val, T limit_val, T delta_val)
 
ir::Shape inferReshapeShape (const ir::Shape &input_shape, const int32_t *shape_buf, const int32_t shape_num_elements)
 
ir::Shape inferReduceShape (const ir::Shape &input_shape, const std::vector< int > &axes, bool keep_dims)
 
ir::Shape inferResizeBilinearShape (const ir::Shape &in_shape, const int32_t output_height, const int32_t output_width)
 
ir::Shape inferSelectShape (const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape, const ir::Shape &input_false_shape)
 
template<typename T >
ir::Shape inferSliceShape (const ir::Shape &input_shape, const T *begins_buf, const T *sizes_buf)
 
ir::Shape inferSpaceToBatchNDShape (const ir::Shape &input_shape, const ir::Shape &block_shape_shape, const ir::Shape &padding_shape, const int32_t *block_shape_buf, const int32_t *padding_buf)
 
ir::Shape inferSplitShape (const ir::Shape input_shape, int axis_value, int num_splits)
 
ir::Shape inferSqueezeShape (const ir::Shape &in_shape, const ir::operation::Squeeze::Param &param)
 
template<typename T >
StridedSliceParams buildStridedSliceParams (const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
ir::Shape inferStridedSliceShape (const ir::Shape &input_shape, const StridedSliceParams &op_params, uint32_t rank)
 
ir::Shape inferTileShape (const ir::Shape &in_shape, const int32_t *multiplier_buf, const int32_t multiplier_size)
 
ir::Shape inferTransposeShape (const ir::Shape &in_shape, const int32_t *perm_buf, const int32_t perm_size)
 
ir::Shape inferUnpackShape (const ir::Shape &input_shape, int axis, int rank)
 
std::pair< int, int > calcConvLikeHeightAndWidth (const int in_h, const int in_w, const int ker_h, const int ker_w, const ir::Padding pad, const ir::Stride stride, const ir::Dilation dilation={1, 1})
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int32_t *shape_buf)
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int64_t *shape_buf)
 
template ir::Shape inferRangeShape (int start_val, int limit_val, int delta_val)
 
template ir::Shape inferRangeShape (float start_val, float limit_val, float delta_val)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int32_t *begins_buf, const int32_t *sizes_buf)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int64_t *begins_buf, const int64_t *sizes_buf)
 
template StridedSliceParams buildStridedSliceParams (const uint32_t *begin, const uint32_t *end, const uint32_t *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
int Clamp (const int v, const int lo, const int hi)
 
int StartForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis)
 
int StopForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis, int start_for_axis)
 

Typedef Documentation

◆ Shapes

using onert::shape_inference::Shapes = std::vector<ir::Shape>

Definition at line 39 of file ShapeInference.h.

Function Documentation

◆ buildStridedSliceParams() [1/2]

template<typename T >
StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const T *  begin,
const T *  end,
const T *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

Definition at line 868 of file ShapeInference.cc.

871{
872 StridedSliceParams op_params;
873 op_params.start_indices_count = rank;
874 op_params.stop_indices_count = rank;
875 op_params.strides_count = rank;
876
877 for (int i = 0; i < op_params.strides_count; ++i)
878 {
879 op_params.start_indices[i] = begin[i];
880 op_params.stop_indices[i] = end[i];
881 op_params.strides[i] = strides[i];
882
883 assert(op_params.strides[i] != 0);
884 }
885
886 op_params.begin_mask = begin_mask;
887 op_params.ellipsis_mask = 0; // NYI
888 op_params.end_mask = end_mask;
889 op_params.new_axis_mask = 0; // NYI
890 op_params.shrink_axis_mask = shrink_axis_mask;
891
892 assert(sizeof(op_params.begin_mask) * 4 >= rank);
893
894 return op_params;
895}

References onert::shape_inference::StridedSliceParams::begin_mask, onert::shape_inference::StridedSliceParams::ellipsis_mask, onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::new_axis_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::start_indices, onert::shape_inference::StridedSliceParams::start_indices_count, onert::shape_inference::StridedSliceParams::stop_indices, onert::shape_inference::StridedSliceParams::stop_indices_count, onert::shape_inference::StridedSliceParams::strides, and onert::shape_inference::StridedSliceParams::strides_count.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ buildStridedSliceParams() [2/2]

template StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const uint32_t *  begin,
const uint32_t *  end,
const uint32_t *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

◆ calcConvLikeHeightAndWidth()

std::pair< int, int > onert::shape_inference::calcConvLikeHeightAndWidth ( const int  in_h,
const int  in_w,
const int  ker_h,
const int  ker_w,
const ir::Padding  pad,
const ir::Stride  stride,
const ir::Dilation  dilation = {1, 1} 
)

Definition at line 93 of file ShapeInference.cc.

96 {1, 1})
97{
98 int32_t out_h = 0, out_w = 0;
99 int32_t effective_filter_w_size = (ker_w - 1) * dilation.width_factor + 1;
100 int32_t effective_filter_h_size = (ker_h - 1) * dilation.height_factor + 1;
101 switch (pad.type)
102 {
103 case ir::PaddingType::SAME:
104 out_h = ceil_div(in_h, stride.vertical);
105 out_w = ceil_div(in_w, stride.horizontal);
106 break;
107 case ir::PaddingType::VALID:
108 out_h = ceil_div(in_h - effective_filter_h_size + 1, stride.vertical);
109 out_w = ceil_div(in_w - effective_filter_w_size + 1, stride.horizontal);
110 break;
111 case ir::PaddingType::EXPLICIT:
112 out_h =
113 (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
114 out_w =
115 (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + 1;
116 break;
117 default:
118 assert(false);
119 }
120
121 return {out_h, out_w};
122}

Referenced by inferConv2DShape(), inferDepthwiseConv2DShape(), and inferPoolShape().
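
Worked example (derived from the listing above): for in_h = in_w = 224, a 3x3 kernel, stride 2 and dilation 1, the effective filter size is 3, so SAME padding gives ceil(224 / 2) = 112, VALID padding gives ceil((224 - 3 + 1) / 2) = 111, and EXPLICIT padding with top = bottom = left = right = 1 gives (224 + 2 - 3) / 2 + 1 = 112 for both height and width.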

◆ Clamp()

int onert::shape_inference::Clamp ( const int  v,
const int  lo,
const int  hi 
)

Definition at line 903 of file ShapeInference.cc.

904{
905 assert(!(hi < lo));
906 if (hi < v)
907 return hi;
908 if (v < lo)
909 return lo;
910 return v;
911}

Referenced by StartForAxis(), and StopForAxis().

◆ inferArgMinMaxShape()

ir::Shape onert::shape_inference::inferArgMinMaxShape ( const ir::Shape input_shape,
int  axis,
int  rank 
)

Definition at line 129 of file ShapeInference.cc.

130{
131 if (axis < 0 || axis >= rank)
132 {
133 throw std::runtime_error("ArgMinMax shape inference: Wrong axis value " + std::to_string(axis));
134 }
135
136 ir::Shape out_shape;
137 for (int idx = 0; idx < rank; ++idx)
138 {
139 if (idx != axis)
140 {
141 int32_t input_dim = input_shape.dim(idx);
142 out_shape.append(input_dim);
143 }
144 }
145
146 return out_shape;
147}

Referenced by onert::exec::DynamicShapeInferer::visit().
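
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder and the include path may need adjusting; ir::Shape is brace-initialized from its dimensions, as in the listings above):

#include "ShapeInference.h" // adjust the include path to your build setup

void argMinMaxExample()
{
  const onert::ir::Shape input{2, 3, 4};
  // The reduced axis is dropped from the output: {2, 3, 4} with axis 1 -> {2, 4}.
  const auto out = onert::shape_inference::inferArgMinMaxShape(input, /*axis=*/1, /*rank=*/3);
}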

◆ inferBatchMatMulShape()

ir::Shape onert::shape_inference::inferBatchMatMulShape ( const ir::Shape lhs_shape,
const ir::Shape rhs_shape,
const ir::operation::BatchMatMul::Param param 
)

Definition at line 231 of file ShapeInference.cc.

233{
234 bool adj_x = param.adj_x;
 235  bool adj_y = param.adj_y;
 236  ir::Shape output_shape;
 237
238 int output_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
239
240 // Extend lhs and rhs shape
241 ir::Shape extended_lhs_shape(lhs_shape);
242 ir::Shape extended_rhs_shape(rhs_shape);
243 extended_lhs_shape.extendRank(output_rank);
244 extended_rhs_shape.extendRank(output_rank);
245
246 for (int i = 0; i < output_rank - 2; i++)
247 {
248 const int lhs_dim = extended_lhs_shape.dim(i);
249 const int rhs_dim = extended_rhs_shape.dim(i);
250 int broadcast_dim = lhs_dim;
251 if (lhs_dim != rhs_dim)
252 {
253 if (lhs_dim == 1)
254 {
255 broadcast_dim = rhs_dim;
256 }
257 else if (rhs_dim != 1)
258 {
 259  throw std::runtime_error{"BatchMatMul shape inference: invalid broadcasting input shape"};
260 }
261 }
262
263 output_shape.append(broadcast_dim);
264 }
265
266 // Fill in the matmul dimensions.
267 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
268 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
269
270 output_shape.append(extended_lhs_shape.dim(lhs_rows_index));
271 output_shape.append(extended_rhs_shape.dim(rhs_cols_index));
272
273 return output_shape;
274}

References onert::ir::operation::BatchMatMul::Param::adj_x and onert::ir::operation::BatchMatMul::Param::adj_y.

Referenced by onert::exec::DynamicShapeInferer::visit().
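
Illustrative usage sketch (added for this page, not generated from the sources; it assumes ir::operation::BatchMatMul::Param can be default-constructed and its adj_x/adj_y fields assigned directly, and the wrapper function is a placeholder):

#include "ShapeInference.h" // adjust the include path to your build setup

void batchMatMulExample()
{
  using namespace onert;
  ir::operation::BatchMatMul::Param param;
  param.adj_x = false; // do not transpose lhs
  param.adj_y = false; // do not transpose rhs
  const ir::Shape lhs{2, 3, 4};
  const ir::Shape rhs{2, 4, 5};
  // Batch dims are broadcast, then [rows, cols] come from lhs/rhs: result is {2, 3, 5}.
  const auto out = shape_inference::inferBatchMatMulShape(lhs, rhs, param);
}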

◆ inferBCQFullyConnectedShape()

ir::Shape onert::shape_inference::inferBCQFullyConnectedShape ( const ir::Shape in_shape,
const ir::Shape cluster_shape,
const int32_t *  cluster_buf 
)

Definition at line 410 of file ShapeInference.cc.

412{
413 assert(cluster_shape.rank() == 2);
414 assert(cluster_shape.dim(1) == 2);
415
416 const auto input_size = in_shape.dim(1);
417 const auto output_size = bcq::getOutputSize(cluster_shape, cluster_buf);
418
419 return {ir::Shape({output_size, input_size})};
420}

References onert::shape_inference::bcq::getOutputSize().

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBCQGatherShape()

ir::Shape onert::shape_inference::inferBCQGatherShape ( const ir::Shape indices_shape,
const ir::Shape cluster_shape,
const int32_t *  cluster_buf,
int  rank,
const ir::operation::BCQGather::Param param 
)

Definition at line 422 of file ShapeInference.cc.

425{
426 ir::Shape out_shape;
427 ir::Shape in_original_shape;
428
429 assert(cluster_shape.rank() == 2);
430 assert(cluster_shape.dim(1) == 2);
431
432 auto hidden_size = param.input_hidden_size;
433 auto axis = param.axis;
434
435 in_original_shape.append(bcq::getOutputSize(cluster_shape, cluster_buf));
436 in_original_shape.append(hidden_size);
437
438 const int indices_rank = indices_shape.rank();
439 for (int idx = 0; idx < rank; ++idx)
440 {
441 if (idx == (int)axis)
442 {
443 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
444 {
445 out_shape.append(indices_shape.dim(indices_idx));
446 }
447 }
448 else
449 {
450 out_shape.append(in_original_shape.dim(idx));
451 }
452 }
453
454 return out_shape;
455}

References onert::ir::operation::BCQGather::Param::axis, onert::shape_inference::bcq::getOutputSize(), and onert::ir::operation::BCQGather::Param::input_hidden_size.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBroadcastToShape()

ir::Shape onert::shape_inference::inferBroadcastToShape ( const ir::Shape  shp_shape,
const int32_t *  shp_buf 
)

Definition at line 280 of file ShapeInference.cc.

281{
282
283 const int num_elements = shp_shape.num_elements();
284
285 assert(num_elements != 0);
286 assert(shp_buf);
287
288 ir::Shape new_shape(num_elements);
289
290 for (int i = 0; i < num_elements; ++i)
291 {
292 assert(shp_buf[i] != 0); // It shouldn't be 0.
293 new_shape.dim(i) = shp_buf[i];
294 }
295
296 return new_shape;
297}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferConcatShape()

ir::Shape onert::shape_inference::inferConcatShape ( const Shapes in_shapes,
const ir::operation::Concat::Param param 
)

Definition at line 299 of file ShapeInference.cc.

300{
301 const int32_t concat_axis = param.axis >= 0 ? param.axis : in_shapes[0].rank() + param.axis;
302 const auto &first_in_shape = in_shapes[0];
303
304 // Check that all shapes are equal except for concat axis dimension
305 for (const auto &in_shape : in_shapes)
306 {
307 if (in_shape.rank() != first_in_shape.rank())
308 throw std::runtime_error("Rank in all input tensors should be same");
309
310 for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
311 if (!(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx)))
312 throw std::runtime_error("All tensor should have same dimension "
313 "except dimension on passed axis");
314 }
315
316 // Calculate output shape
317 ir::Shape out_shape(first_in_shape);
318 out_shape.dim(concat_axis) = 0;
319 for (const auto &in_shape : in_shapes)
320 out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
321 return out_shape;
322}

References onert::ir::operation::Concat::Param::axis.

Referenced by onert::exec::DynamicShapeInferer::visit().
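
Illustrative usage sketch (added for this page, not generated from the sources; it assumes ir::operation::Concat::Param can be default-constructed and its axis field assigned, and the wrapper function is a placeholder):

#include "ShapeInference.h" // adjust the include path to your build setup

void concatExample()
{
  using namespace onert;
  ir::operation::Concat::Param param;
  param.axis = 1;
  const shape_inference::Shapes inputs{ir::Shape{1, 2, 3}, ir::Shape{1, 4, 3}};
  // All non-axis dimensions must match; the axis dimension is summed: {1, 6, 3}.
  const auto out = shape_inference::inferConcatShape(inputs, param);
}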

◆ inferConv2DShape()

ir::Shape onert::shape_inference::inferConv2DShape ( const ir::Shape in_shape,
const ir::Shape ker_shape,
const ir::operation::Conv2D::Param param 
)

Definition at line 324 of file ShapeInference.cc.

326{
327 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
328 throw std::runtime_error{"Conv2D: stride values must be positive"};
329
330 auto ifm_shape = in_shape.asFeature();
331
332 // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
333 auto kf_shape = ker_shape.asFeature();
334 assert(ifm_shape.C == kf_shape.C);
335
336 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
337 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
338
339 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.N};
340}

References calcConvLikeHeightAndWidth(), onert::ir::operation::Conv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::Conv2D::Param::padding, onert::ir::operation::Conv2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().
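
Worked example (derived from the listing above): for an NHWC input of shape {1, 224, 224, 3}, a kernel of shape {32, 3, 3, 3} in [depth_out, kernel_height, kernel_width, depth_in] layout, stride 2 and SAME padding, calcConvLikeHeightAndWidth() yields out_h = out_w = ceil(224 / 2) = 112, so the inferred output shape is {1, 112, 112, 32}.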

◆ inferDepthwiseConv2DShape()

ir::Shape onert::shape_inference::inferDepthwiseConv2DShape ( const ir::Shape in_shape,
const ir::Shape ker_shape,
const ir::operation::DepthwiseConv2D::Param param 
)

Definition at line 342 of file ShapeInference.cc.

344{
345 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
346 throw std::runtime_error{"DepthwiseConv2D: stride values must be positive"};
347
348 auto ifm_shape = in_shape.asFeature();
349
350 // Kernel format is [1, kernel_height, kernel_width, depth_out]
351 auto kf_shape = ker_shape.asFeature();
352 assert(kf_shape.C == static_cast<int32_t>(ifm_shape.C * param.multiplier));
353 assert(kf_shape.N == 1);
354
355 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
356 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
357
358 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.C};
359}

References calcConvLikeHeightAndWidth(), onert::ir::operation::DepthwiseConv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::DepthwiseConv2D::Param::multiplier, onert::ir::operation::DepthwiseConv2D::Param::padding, onert::ir::operation::DepthwiseConv2D::Param::stride, and onert::ir::Stride::vertical.

◆ inferEltwiseShape()

ir::Shape onert::shape_inference::inferEltwiseShape ( const ir::Shape lhs_shape,
const ir::Shape rhs_shape 
)

Definition at line 124 of file ShapeInference.cc.

125{
126 return broadcastShapes(lhs_shape, rhs_shape);
127}

◆ inferExpandDimsShape()

ir::Shape onert::shape_inference::inferExpandDimsShape ( const ir::Shape in_shape,
int32_t  axis 
)

Definition at line 361 of file ShapeInference.cc.

362{
363 ir::Shape out_shape(in_shape.rank() + 1);
364
365 axis = ((axis >= 0) ? axis : /* when axis < 0 */ (out_shape.rank() + axis));
366 if (!(0 <= axis && axis <= in_shape.rank()))
367 throw std::runtime_error("axis of dim is out of range");
368
369 for (int x = 0, out_x = 0; out_x < out_shape.rank(); ++out_x)
370 {
371 if (out_x == axis)
372 out_shape.dim(out_x) = 1;
373 else
374 out_shape.dim(out_x) = in_shape.dim(x++);
375 }
376
377 return out_shape;
378}

Referenced by onert::exec::DynamicShapeInferer::visit().
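
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder and the include path may need adjusting):

#include "ShapeInference.h" // adjust the include path to your build setup

void expandDimsExample()
{
  const onert::ir::Shape input{2, 3};
  // A dimension of size 1 is inserted at the (possibly negative) axis:
  // axis 0 gives {1, 2, 3}; axis -1 resolves to 2 and gives {2, 3, 1}.
  const auto front = onert::shape_inference::inferExpandDimsShape(input, 0);
  const auto back = onert::shape_inference::inferExpandDimsShape(input, -1);
}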

◆ inferFillShape() [1/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const int32_t *  shape_buf 
)

◆ inferFillShape() [2/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const int64_t *  shape_buf 
)

◆ inferFillShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const T *  shape_buf 
)

Definition at line 380 of file ShapeInference.cc.

381{
382 ir::Shape out_shape(fill_shape.dim(0));
383
384 for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
385 {
386 out_shape.dim(out_x) = static_cast<int32_t>(shape_buf[out_x]);
387 }
388
389 return out_shape;
390}

◆ inferFullyConnectedShape()

ir::Shape onert::shape_inference::inferFullyConnectedShape ( const ir::Shape in_shape,
const ir::Shape ker_shape 
)

Definition at line 396 of file ShapeInference.cc.

397{
398 assert(in_shape.rank() >= 2);
399 assert(ker_shape.rank() == 2);
400
401 const auto input_size_with_batch = in_shape.num_elements();
402 const auto num_units = ker_shape.dim(0);
403 const auto input_size = ker_shape.dim(1);
404 const auto batch_size = input_size_with_batch / input_size;
405 assert(input_size_with_batch % input_size == 0);
406
407 return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
408}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferGatherShape()

ir::Shape onert::shape_inference::inferGatherShape ( const ir::Shape input_shape,
const ir::Shape indices_shape,
int  axis,
int  rank 
)

Definition at line 457 of file ShapeInference.cc.

459{
460 ir::Shape out_shape;
461
462 const int indices_rank = indices_shape.rank();
463
464 for (int idx = 0; idx < rank; ++idx)
465 {
466 if (idx == axis)
467 {
468 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
469 {
470 out_shape.append(indices_shape.dim(indices_idx));
471 }
472 }
473 else
474 {
475 out_shape.append(input_shape.dim(idx));
476 }
477 }
478
479 return out_shape;
480}

Referenced by onert::exec::DynamicShapeInferer::visit().
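
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder and the include path may need adjusting):

#include "ShapeInference.h" // adjust the include path to your build setup

void gatherExample()
{
  const onert::ir::Shape input{3, 4, 5};
  const onert::ir::Shape indices{2, 6};
  // The axis dimension of the input is replaced by the whole indices shape:
  // {3, 4, 5} gathered along axis 1 with indices {2, 6} -> {3, 2, 6, 5}.
  const auto out = onert::shape_inference::inferGatherShape(input, indices, /*axis=*/1, /*rank=*/3);
}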

◆ inferOnehotShape()

ir::Shape onert::shape_inference::inferOnehotShape ( const ir::Shape input_shape,
const int  depth,
int  axis 
)

Definition at line 482 of file ShapeInference.cc.

483{
484 assert(depth >= 0);
485 const auto rank = input_shape.rank() + 1;
486 ir::Shape newShape(rank);
487
488 axis = (axis == -1) ? (rank - 1) : axis;
489
490 for (int i = 0; i < rank; ++i)
491 {
492 if (i < axis)
493 {
494 newShape.dim(i) = input_shape.dim(i);
495 }
496 else if (i == axis)
497 {
498 newShape.dim(i) = depth;
499 }
500 else
501 {
502 newShape.dim(i) = input_shape.dim(i - 1);
503 }
504 }
505
506 return newShape;
507}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPackShape()

ir::Shape onert::shape_inference::inferPackShape ( const ir::Shape input_shape,
int  axis,
int  rank,
int  num 
)

Definition at line 509 of file ShapeInference.cc.

510{
511 ir::Shape out_shape;
512 int in_idx = 0;
513
514 for (int out_idx = 0; out_idx < rank; ++out_idx)
515 {
516 if (out_idx == axis)
517 {
518 out_shape.append(num);
519 }
520 else
521 {
522 out_shape.append(input_shape.dim(in_idx++));
523 }
524 }
525
526 return out_shape;
527}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPadShape()

ir::Shape onert::shape_inference::inferPadShape ( const ir::Shape in_shape,
const int32_t *  pad_buf,
const size_t  num_pads 
)

Definition at line 529 of file ShapeInference.cc.

530{
531 assert(num_pads % 2 == 0);
532 const int32_t rank = num_pads / 2;
533
534 ir::Shape ret(rank);
535 for (int32_t i = 0; i < rank; ++i)
536 {
537 const auto before_padding = pad_buf[i * 2];
538 const auto after_padding = pad_buf[i * 2 + 1];
539
540 ret.dim(i) = in_shape.dim(i) + before_padding + after_padding;
541 }
542
543 return ret;
544}

Referenced by onert::exec::DynamicShapeInferer::visit().
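
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder and the include path may need adjusting):

#include "ShapeInference.h" // adjust the include path to your build setup
#include <cstdint>

void padExample()
{
  const onert::ir::Shape input{2, 3};
  // One (before, after) pair per dimension, so num_pads is 2 * rank.
  const int32_t pads[] = {1, 1, 0, 2};
  // dim 0: 2 + 1 + 1 = 4, dim 1: 3 + 0 + 2 = 5 -> {4, 5}.
  const auto out = onert::shape_inference::inferPadShape(input, pads, /*num_pads=*/4);
}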

◆ inferPoolShape()

ir::Shape onert::shape_inference::inferPoolShape ( const ir::Shape in_shape,
const ir::operation::Pool2D::Param param 
)

Definition at line 546 of file ShapeInference.cc.

547{
548 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
549 throw std::runtime_error{"Pool2D: stride values must be positive"};
550
551 auto ifm_shape = in_shape.asFeature();
552 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh,
553 param.kw, param.padding, param.stride);
554 // Pooling don't change number of channels and batch size
555 return ir::Shape{ifm_shape.N, out_h, out_w, ifm_shape.C};
556}

References calcConvLikeHeightAndWidth(), onert::ir::Stride::horizontal, onert::ir::operation::Pool2D::Param::kh, onert::ir::operation::Pool2D::Param::kw, onert::ir::operation::Pool2D::Param::padding, onert::ir::operation::Pool2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferRangeShape() [1/3]

template ir::Shape onert::shape_inference::inferRangeShape ( float  start_val,
float  limit_val,
float  delta_val 
)

◆ inferRangeShape() [2/3]

template ir::Shape onert::shape_inference::inferRangeShape ( int  start_val,
int  limit_val,
int  delta_val 
)

◆ inferRangeShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferRangeShape ( T  start_val,
T  limit_val,
T  delta_val 
)

Definition at line 583 of file ShapeInference.cc.

584{
585 ir::Shape out_shape(static_cast<int>(1));
586
587 out_shape.dim(0) =
588 (std::is_integral<T>::value
589 ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
590 : std::ceil(std::abs((start_val - limit_val) / delta_val)));
591 return out_shape;
592}
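
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder, and only the int and float instantiations listed above are exported from ShapeInference.cc):

#include "ShapeInference.h" // adjust the include path to your build setup

void rangeExample()
{
  // Integral case: (|0 - 10| + |3| - 1) / |3| = 4, i.e. the values 0, 3, 6, 9.
  const auto int_out = onert::shape_inference::inferRangeShape(0, 10, 3);
  // Floating-point case: ceil(|(0.0f - 1.0f) / 0.3f|) = 4.
  const auto float_out = onert::shape_inference::inferRangeShape(0.0f, 1.0f, 0.3f);
}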

◆ inferReduceShape()

ir::Shape onert::shape_inference::inferReduceShape ( const ir::Shape input_shape,
const std::vector< int > &  axes,
bool  keep_dims 
)

Definition at line 149 of file ShapeInference.cc.

151{
152 int num_axis = axes.size();
153 int input_num_dims = input_shape.rank();
154 if (input_num_dims == 0)
155 {
156 ir::Shape out_shape(0);
157 return out_shape;
158 }
159 if (keep_dims)
160 {
161 ir::Shape out_shape;
162 for (int idx = 0; idx < input_num_dims; ++idx)
163 {
164 bool is_axis = false;
165 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
166 {
167 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
168 {
169 is_axis = true;
170 break;
171 }
172 }
173 if (is_axis)
174 {
175 out_shape.append(1);
176 }
177 else
178 {
179 out_shape.append(input_shape.dim(idx));
180 }
181 }
182 return out_shape;
183 }
184 else
185 {
186 // Calculates size of reducing axis.
187 for (int i = 0; i < num_axis; ++i)
188 {
189 int current = axes[i];
190 if (!(-input_num_dims <= current && current < input_num_dims))
191 throw std::runtime_error{"Invalid dim value " + std::to_string(current)};
192 if (current < 0)
193 {
194 current += input_num_dims;
195 }
196 for (int j = 0; j < i; ++j)
197 {
198 int previous = axes[j];
199 if (previous < 0)
200 {
201 previous += input_num_dims;
202 }
203 if (current == previous)
204 {
205 break;
206 }
207 }
208 }
209 // Determines output dimensions.
210 ir::Shape out_shape;
211 for (int idx = 0; idx < input_num_dims; ++idx)
212 {
213 bool is_axis = false;
214 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
215 {
216 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
217 {
218 is_axis = true;
219 break;
220 }
221 }
222 if (!is_axis)
223 {
224 out_shape.append(input_shape.dim(idx));
225 }
226 }
227 return out_shape;
228 }
229}

Referenced by onert::exec::DynamicShapeInferer::visit().
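
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder and the include path may need adjusting):

#include "ShapeInference.h" // adjust the include path to your build setup
#include <vector>

void reduceExample()
{
  const onert::ir::Shape input{2, 3, 4};
  const std::vector<int> axes{1};
  // keep_dims = false removes the reduced axis: {2, 4}.
  const auto squeezed = onert::shape_inference::inferReduceShape(input, axes, false);
  // keep_dims = true keeps it with size 1: {2, 1, 4}.
  const auto kept = onert::shape_inference::inferReduceShape(input, axes, true);
}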

◆ inferReshapeShape()

ir::Shape onert::shape_inference::inferReshapeShape ( const ir::Shape input_shape,
const int32_t *  shape_buf,
const int32_t  shape_num_elements 
)

Definition at line 598 of file ShapeInference.cc.

600{
601 ir::Shape ret(shape_num_elements);
602 int32_t flatten_dim = ir::Shape::kUnspecifiedDim;
603 auto total_num_elements = input_shape.num_elements();
604 for (int32_t i = 0; i < shape_num_elements; ++i)
605 {
606 if (shape_buf[i] < 0)
607 {
608 if (flatten_dim != ir::Shape::kUnspecifiedDim)
609 throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
610 flatten_dim = i;
611 ret.dim(i) = 1;
612 }
613 else
614 {
615 ret.dim(i) = shape_buf[i];
616 }
617 }
618 if (flatten_dim != ir::Shape::kUnspecifiedDim)
619 ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
620
621 // Check reshapable
622 if (total_num_elements != static_cast<size_t>(ret.num_elements()))
623 {
624 // Multi batch case
625 // TODO Handle multi batch case more precisely on runtime level
626 if ((ret.dim(0) == 1) &&
627 (total_num_elements == static_cast<size_t>(ret.num_elements() * input_shape.dim(0))))
628 ret.dim(0) = input_shape.dim(0);
629 else
630 throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
631 }
632
633 return ret;
634}

Referenced by onert::exec::DynamicShapeInferer::visit().
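
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder and the include path may need adjusting):

#include "ShapeInference.h" // adjust the include path to your build setup
#include <cstdint>

void reshapeExample()
{
  const onert::ir::Shape input{2, 3, 4}; // 24 elements
  // At most one entry may be negative; that dimension is inferred from the element count.
  const int32_t new_shape[] = {4, -1};
  // 24 / 4 = 6, so the result is {4, 6}.
  const auto out = onert::shape_inference::inferReshapeShape(input, new_shape, /*shape_num_elements=*/2);
}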

◆ inferResizeBilinearShape()

ir::Shape onert::shape_inference::inferResizeBilinearShape ( const ir::Shape in_shape,
const int32_t  output_height,
const int32_t  output_width 
)

Definition at line 558 of file ShapeInference.cc.

560{
561 assert(in_shape.rank() == 4);
562 if (output_height < 0)
563 {
564 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_height = " +
565 std::to_string(output_height)};
566 }
567 if (output_width < 0)
568 {
569 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_width = " +
570 std::to_string(output_width)};
571 }
572
573 ir::Shape ret(in_shape.rank());
574
575 ret.dim(0) = in_shape.dim(0);
576 ret.dim(1) = output_height;
577 ret.dim(2) = output_width;
578 ret.dim(3) = in_shape.dim(3);
579
580 return ret;
581}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSelectShape()

ir::Shape onert::shape_inference::inferSelectShape ( const ir::Shape input_cond_shape,
const ir::Shape input_true_shape,
const ir::Shape input_false_shape 
)

Definition at line 636 of file ShapeInference.cc.

638{
639 auto haveSameShapes = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
640 const ir::Shape &input_false_shape) {
641 if ((input_cond_shape.rank() != input_true_shape.rank()) ||
642 input_cond_shape.rank() != input_false_shape.rank())
643 {
644 return false;
645 }
646
647 int rank = input_cond_shape.rank();
648 for (int i = 0; i < rank; ++i)
649 {
650 if (input_cond_shape.dim(i) != input_true_shape.dim(i) ||
651 input_cond_shape.dim(i) != input_false_shape.dim(i))
652 {
653 return false;
654 }
655 }
656
657 return true;
658 };
659
660 auto calculateShape = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
661 const ir::Shape &input_false_shape, ir::Shape &new_shape) {
662 ir::Shape cond_shape = input_cond_shape;
663 ir::Shape true_shape = input_true_shape;
664 ir::Shape false_shape = input_false_shape;
665 int most_rank =
666 (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
667 ? cond_shape.rank()
668 : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
669
670 ir::Shape calculate_shape(most_rank);
671
672 cond_shape.extendRank(most_rank);
673 true_shape.extendRank(most_rank);
674 false_shape.extendRank(most_rank);
675
676 for (int i = 0; i < most_rank; ++i)
677 {
678 calculate_shape.dim(i) =
679 (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
680 ? cond_shape.dim(i)
681 : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
682
683 if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
684 (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
685 (false_shape.dim(i) != calculate_shape.dim(i) && false_shape.dim(i) != 1))
686 {
687 return false;
688 }
689 }
690
691 new_shape = calculate_shape;
692
693 return true;
694 };
695
696 bool havesame = haveSameShapes(input_cond_shape, input_true_shape, input_false_shape);
697 if (havesame)
698 {
699 return input_cond_shape;
700 }
701
702 ir::Shape new_shape;
703 bool possible = calculateShape(input_cond_shape, input_true_shape, input_false_shape, new_shape);
704
705 if (!possible)
706 {
707 throw std::runtime_error("Broadcasting is not possible.");
708 }
709
710 return new_shape;
711}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSliceShape() [1/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const int32_t *  begins_buf,
const int32_t *  sizes_buf 
)

◆ inferSliceShape() [2/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const int64_t *  begins_buf,
const int64_t *  sizes_buf 
)

◆ inferSliceShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const T *  begins_buf,
const T *  sizes_buf 
)

Definition at line 714 of file ShapeInference.cc.

715{
716 const uint32_t rank = input_shape.rank();
717 ir::Shape out_shape(rank);
718
719 for (uint32_t idx = 0; idx < rank; ++idx)
720 {
721 const auto input_dim = input_shape.dim(idx);
722
723 // begin is zero-based
724 auto begin = begins_buf[idx];
725 if (begin < 0)
726 throw std::runtime_error("shape inference Slice: Invalid begin.");
727
728 // size is one-based
729 auto size = sizes_buf[idx];
730 if (size < -1)
731 throw std::runtime_error("shape inference Slice: Invalid size.");
732
733 if (size == -1)
734 {
735 size = input_dim - begin;
736 }
737 else
738 {
739 if (input_dim < static_cast<int32_t>(begin + size))
740 throw std::runtime_error("shape inference Slice: Invalid begin and size.");
741 }
742 out_shape.dim(idx) = static_cast<int32_t>(size);
743 }
744
745 return out_shape;
746}

Referenced by onert::exec::DynamicShapeInferer::visit().
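
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder, and the int32_t and int64_t instantiations listed above are the ones exported from ShapeInference.cc):

#include "ShapeInference.h" // adjust the include path to your build setup
#include <cstdint>

void sliceExample()
{
  const onert::ir::Shape input{3, 5};
  const int32_t begins[] = {1, 0};
  // A size of -1 means "to the end of that dimension".
  const int32_t sizes[] = {-1, 3};
  // dim 0: 3 - 1 = 2, dim 1: 3 -> {2, 3}.
  const auto out = onert::shape_inference::inferSliceShape(input, begins, sizes);
}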

◆ inferSpaceToBatchNDShape()

ir::Shape onert::shape_inference::inferSpaceToBatchNDShape ( const ir::Shape input_shape,
const ir::Shape block_shape_shape,
const ir::Shape padding_shape,
const int32_t *  block_shape_buf,
const int32_t *  padding_buf 
)

Definition at line 753 of file ShapeInference.cc.

757{
758 const uint32_t rank = input_shape.rank();
759 ir::Shape out_shape(rank);
760
761 // Currently, only 4D NHWC input/output op_context are supported.
762 // The 4D array need to have exactly 2 spatial dimensions.
763 // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
764 [[maybe_unused]] const int32_t kInputDimensionNum = 4;
765 [[maybe_unused]] const int32_t kBlockSizeDimensionNum = 1;
766 const int32_t kSpatialDimensionNum = 2;
767
768 assert(block_shape_shape.rank() == kBlockSizeDimensionNum);
769 assert(block_shape_shape.dim(0) == kSpatialDimensionNum);
770 assert(padding_shape.dim(0) == kSpatialDimensionNum);
771 assert(padding_shape.dim(1) == 2); // fixed, meaning left/right padding for each element
772 assert(padding_shape.rank() == 2); // fixed, meaning dimension(dim 0) and padding length(dim 1)
773
774 // Ensures the input height and width (with padding) is a multiple of block
775 // shape height and width.
776 for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
777 {
778 int final_dim_size =
779 (input_shape.dim(dim + 1) + padding_buf[dim * 2] + padding_buf[dim * 2 + 1]);
780
781 assert(final_dim_size % block_shape_buf[dim] == 0);
782
783 out_shape.dim(dim + 1) = final_dim_size / block_shape_buf[dim];
784 }
785
786 const int output_batch_size = input_shape.dim(0) * block_shape_buf[0] * block_shape_buf[1];
787 const int output_channel_size = input_shape.dim(3);
788
789 out_shape.dim(0) = output_batch_size;
790 out_shape.dim(3) = output_channel_size;
791
792 return out_shape;
793}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSplitShape()

ir::Shape onert::shape_inference::inferSplitShape ( const ir::Shape  input_shape,
int  axis_value,
int  num_splits 
)

Definition at line 795 of file ShapeInference.cc.

796{
797 ir::Shape newShape(input_shape);
798
799 assert(axis_value >= 0);
800 assert(axis_value < input_shape.rank());
801
802 const int input_size = input_shape.dim(axis_value);
803 assert(input_size % num_splits == 0);
804 const int slice_size = input_size / num_splits;
805
806 newShape.dim(axis_value) = slice_size;
807
808 return newShape;
809}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSqueezeShape()

ir::Shape onert::shape_inference::inferSqueezeShape ( const ir::Shape in_shape,
const ir::operation::Squeeze::Param param 
)

Definition at line 811 of file ShapeInference.cc.

812{
813 const int ndims = param.ndim;
814 const int *squeeze_dims = param.dims;
815 bool should_squeeze[8] = {false};
816 int num_squeezed_dims = 0;
817 int shape_rank = in_shape.rank();
818 if (ndims == 0)
819 {
820 for (int idx = 0; idx < shape_rank; ++idx)
821 {
822 if (in_shape.dim(idx) == 1)
823 {
824 should_squeeze[idx] = true;
825 ++num_squeezed_dims;
826 }
827 }
828 }
829 else
830 {
831 for (int idx = 0; idx < ndims; ++idx)
832 {
833 int current = squeeze_dims[idx];
834 if (current < 0)
835 {
836 current += shape_rank;
837 }
838
839 if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
840 {
841 throw std::runtime_error(
842 "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
843 }
844
845 if (!should_squeeze[current])
846 {
847 ++num_squeezed_dims;
848 }
849 should_squeeze[current] = true;
850 }
851 }
852
853 // Set output shape.
854 ir::Shape out_shape(shape_rank - num_squeezed_dims);
855 for (int in_idx = 0, out_idx = 0; in_idx < shape_rank; ++in_idx)
856 {
857 if (!should_squeeze[in_idx])
858 {
859 out_shape.dim(out_idx++) = in_shape.dim(in_idx);
860 }
861 }
862
863 return out_shape;
864}

References onert::ir::operation::Squeeze::Param::dims, and onert::ir::operation::Squeeze::Param::ndim.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferStridedSliceShape()

ir::Shape onert::shape_inference::inferStridedSliceShape ( const ir::Shape input_shape,
const StridedSliceParams op_params,
uint32_t  rank 
)

Definition at line 1018 of file ShapeInference.cc.

1020{
1021 ir::Shape out_shape;
1022
1023 for (uint32_t idx = 0; idx < rank; ++idx)
1024 {
1025 int32_t stride = op_params.strides[idx];
1026 int32_t begin = StartForAxis(op_params, input_shape, idx);
1027 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
1028
1029 // When shrinking an axis, the end position does not matter (and can be
1030 // incorrect when negative indexing is used, see Issue #19260). Always use
1031 // begin + 1 to generate a length 1 slice, since begin has
1032 // already been adjusted for negative indices by StartForAxis.
1033 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
1034 if (shrink_axis)
1035 {
1036 end = begin + 1;
1037 }
1038
1039 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
1040 dim_shape = dim_shape < 0 ? 0 : dim_shape;
1041 if (!shrink_axis)
1042 {
1043 out_shape.append(dim_shape);
1044 }
1045 }
1046
1047 return out_shape;
1048}

References onert::shape_inference::StridedSliceParams::shrink_axis_mask, StartForAxis(), StopForAxis(), and onert::shape_inference::StridedSliceParams::strides.

Referenced by onert::exec::DynamicShapeInferer::visit().
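
Illustrative sketch combining buildStridedSliceParams() and inferStridedSliceShape() (added for this page, not generated from the sources; the wrapper function is a placeholder, and the uint32_t instantiation listed above is the one exported from ShapeInference.cc):

#include "ShapeInference.h" // adjust the include path to your build setup
#include <cstdint>

void stridedSliceExample()
{
  const onert::ir::Shape input{4, 4};
  const uint32_t begin[] = {1, 0};
  const uint32_t end[] = {3, 4};
  const uint32_t strides[] = {1, 1};
  const auto params = onert::shape_inference::buildStridedSliceParams(
    begin, end, strides, /*begin_mask=*/0, /*end_mask=*/0, /*shrink_axis_mask=*/0, /*rank=*/2);
  // Axis 0 keeps indices 1..2, axis 1 is taken whole: {2, 4}.
  // With shrink_axis_mask = 1 the first axis would be dropped, giving {4}.
  const auto out = onert::shape_inference::inferStridedSliceShape(input, params, /*rank=*/2);
}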

◆ inferTileShape()

ir::Shape onert::shape_inference::inferTileShape ( const ir::Shape in_shape,
const int32_t *  multiplier_buf,
const int32_t  multiplier_size 
)

Definition at line 1050 of file ShapeInference.cc.

1052{
1053 if (multiplier_size != in_shape.rank())
1054 {
1055 throw std::runtime_error(
1056 "inferTileShape failed, input rank: " + std::to_string(in_shape.rank()) +
1057 ", bad multipliers size: " + std::to_string(multiplier_size) + "");
1058 }
1059 ir::Shape new_Shape(in_shape.rank());
1060
1061 for (int i = 0; i < in_shape.rank(); ++i)
1062 {
1063  assert(multiplier_buf[i]); // multiplier_buf[i] should not be 0.
1064 new_Shape.dim(i) = in_shape.dim(i) * multiplier_buf[i];
1065 }
1066 return new_Shape;
1067}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferTransposeShape()

ir::Shape onert::shape_inference::inferTransposeShape ( const ir::Shape in_shape,
const int32_t *  perm_buf,
const int32_t  perm_size 
)

Definition at line 1069 of file ShapeInference.cc.

1071{
1072 const auto rank = in_shape.rank();
1073 if (perm_size > rank)
1074 {
1075 throw std::runtime_error("inferTransposeShape failed, bad permutation size: " +
1076 std::to_string(perm_size));
1077 }
1078
1079 const int32_t *perm_data = perm_buf;
1080 std::vector<int32_t> regular_perm_vec;
1081 if (perm_size == 0)
1082 {
1083 // perm_data will be set to (n-1...0)
1084 regular_perm_vec.resize(rank);
1085 std::iota(regular_perm_vec.begin(), regular_perm_vec.end(), 0);
1086 std::reverse(regular_perm_vec.begin(), regular_perm_vec.end());
1087 perm_data = regular_perm_vec.data();
1088 }
1089 else
1090 {
1091 assert(rank == perm_size);
1092 }
1093
1094 ir::Shape out_shape(rank);
1095 std::vector<bool> visit_perms(rank, false);
1096 for (int idx = 0; idx < rank; idx++)
1097 {
1098 const auto perm_val = perm_data[idx];
1099 // Check invalid permutation value
1100 if (perm_val < 0 || perm_val >= rank)
1101 {
1102 throw std::runtime_error("inferTransposeShape failed, bad permutation value: " +
1103 std::to_string(perm_val));
1104 }
1105
1106 // Check duplicated permutation value
1107 if (visit_perms.at(perm_val))
1108 {
1109 throw std::runtime_error("inferTransposeShape failed, duplicated permutation value: " +
1110 std::to_string(perm_val));
1111 }
1112 visit_perms.at(perm_val) = true;
1113
1114 out_shape.dim(idx) = in_shape.dim(perm_val);
1115 }
1116 return out_shape;
1117}

Referenced by onert::exec::DynamicShapeInferer::visit().
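
Illustrative usage sketch (added for this page, not generated from the sources; the wrapper function is a placeholder and the include path may need adjusting):

#include "ShapeInference.h" // adjust the include path to your build setup
#include <cstdint>

void transposeExample()
{
  const onert::ir::Shape input{2, 3, 4};
  const int32_t perm[] = {2, 0, 1};
  // out.dim(i) = input.dim(perm[i]) -> {4, 2, 3}.
  const auto out = onert::shape_inference::inferTransposeShape(input, perm, /*perm_size=*/3);
  // Passing perm_size == 0 reverses the dimensions instead: {4, 3, 2};
  // the permutation buffer is not read in that case.
  const auto reversed = onert::shape_inference::inferTransposeShape(input, nullptr, 0);
}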

◆ inferUnpackShape()

ir::Shape onert::shape_inference::inferUnpackShape ( const ir::Shape input_shape,
int  axis,
int  rank 
)

Definition at line 1119 of file ShapeInference.cc.

1120{
1121 ir::Shape out_shape;
1122
1123 for (int out_idx = 0; out_idx < rank; out_idx++)
1124 {
1125 if (out_idx != axis)
1126 {
1127 out_shape.append(input_shape.dim(out_idx));
1128 }
1129 }
1130
1131 return out_shape;
1132}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ StartForAxis()

int onert::shape_inference::StartForAxis ( const StridedSliceParams params,
const ir::Shape input_shape,
int  axis 
)

Definition at line 913 of file ShapeInference.cc.

914{
915 const auto begin_mask = params.begin_mask;
916 const auto *start_indices = params.start_indices;
917 const auto *strides = params.strides;
918 // Begin with the specified index.
919 int start = start_indices[axis];
920
921 // begin_mask override
922 if (begin_mask & 1 << axis)
923 {
924 if (strides[axis] > 0)
925 {
926 // Forward iteration - use the first element. These values will get
927 // clamped below (Note: We could have set them to 0 and axis_size-1, but
928 // use lowest() and max() to maintain symmetry with StopForAxis())
929 start = std::numeric_limits<int>::lowest();
930 }
931 else
932 {
933 // Backward iteration - use the last element.
934 start = std::numeric_limits<int>::max();
935 }
936 }
937
938 // Handle negative indices
939 int axis_size = input_shape.dim(axis);
940 if (start < 0)
941 {
942 start += axis_size;
943 }
944
945 // Clamping
946 start = Clamp(start, 0, axis_size - 1);
947
948 return start;
949}

References onert::shape_inference::StridedSliceParams::begin_mask, Clamp(), onert::shape_inference::StridedSliceParams::start_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().

◆ StopForAxis()

int onert::shape_inference::StopForAxis ( const StridedSliceParams params,
const ir::Shape input_shape,
int  axis,
int  start_for_axis 
)

Definition at line 956 of file ShapeInference.cc.

958{
959 const auto end_mask = params.end_mask;
960 const auto shrink_axis_mask = params.shrink_axis_mask;
961 const auto *stop_indices = params.stop_indices;
962 const auto *strides = params.strides;
963
964 // Begin with the specified index
965 const bool shrink_axis = shrink_axis_mask & (1 << axis);
966 int stop = stop_indices[axis];
967
968 // When shrinking an axis, the end position does not matter (and can be
969 // incorrect when negative indexing is used, see Issue #19260). Always use
970 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
971 // already been adjusted for negative indices.
972 if (shrink_axis)
973 {
974 stop = start_for_axis + 1;
975 }
976
977 // end_mask override
978 if (end_mask & (1 << axis))
979 {
980 if (strides[axis] > 0)
981 {
982 // Forward iteration - use the last element. These values will get
983 // clamped below
984 stop = std::numeric_limits<int>::max();
985 }
986 else
987 {
988 // Backward iteration - use the first element.
989 stop = std::numeric_limits<int>::lowest();
990 }
991 }
992
993 // Handle negative indices
994
995 const int axis_size = input_shape.dim(axis);
996 if (stop < 0)
997 {
998 stop += axis_size;
999 }
1000
1001 // Clamping
1002 // Because the end index points one past the last element, we need slightly
1003 // different clamping ranges depending on the direction.
1004 if (strides[axis] > 0)
1005 {
1006 // Forward iteration
1007 stop = Clamp(stop, 0, axis_size);
1008 }
1009 else
1010 {
1011 // Backward iteration
1012 stop = Clamp(stop, -1, axis_size - 1);
1013 }
1014
1015 return stop;
1016}

References Clamp(), onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::stop_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().