ONE - On-device Neural Engine
onert::shape_inference Namespace Reference

Namespaces

namespace  bcq
 

Data Structures

struct  StridedSliceParams
 

Typedefs

using Shapes = std::vector< ir::Shape >
 

Functions

ir::Shape inferArgMinMaxShape (const ir::Shape &input_shape, int axis, int rank)
 
ir::Shape inferBatchMatMulShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape, const ir::operation::BatchMatMul::Param &param)
 
ir::Shape inferBCQFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf)
 
ir::Shape inferBCQGatherShape (const ir::Shape &indices_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf, int rank, const ir::operation::BCQGather::Param &param)
 
ir::Shape inferBroadcastToShape (const ir::Shape shp_shape, const int32_t *shp_buf)
 
ir::Shape inferConcatShape (const Shapes &in_shapes, const ir::operation::Concat::Param &param)
 
ir::Shape inferConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::Conv2D::Param &param)
 
ir::Shape inferDepthwiseConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::DepthwiseConv2D::Param &param)
 
ir::Shape inferEltwiseShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
 
ir::Shape inferExpandDimsShape (const ir::Shape &in_shape, int32_t axis)
 
template<typename T >
ir::Shape inferFillShape (const ir::Shape &fill_shape, const T *shape_buf)
 
ir::Shape inferFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &ker_shape)
 
ir::Shape inferGatherShape (const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis, int rank)
 
ir::Shape inferOnehotShape (const ir::Shape &input_shape, const int depth, int axis)
 
ir::Shape inferPackShape (const ir::Shape &input_shape, int axis, int rank, int num)
 
ir::Shape inferPadShape (const ir::Shape &in_shape, const int32_t *pad_buf, const size_t num_pads)
 
ir::Shape inferPoolShape (const ir::Shape &in_shape, const ir::operation::Pool2D::Param &param)
 
template<typename T >
ir::Shape inferRangeShape (T start_val, T limit_val, T delta_val)
 
ir::Shape inferReshapeShape (const ir::Shape &input_shape, const int32_t *shape_buf, const int32_t shape_num_elements)
 
ir::Shape inferReduceShape (const ir::Shape &input_shape, const std::vector< int > &axes, bool keep_dims)
 
ir::Shape inferResizeBilinearShape (const ir::Shape &in_shape, const int32_t output_height, const int32_t output_width)
 
ir::Shape inferSelectShape (const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape, const ir::Shape &input_false_shape)
 
template<typename T >
ir::Shape inferSliceShape (const ir::Shape &input_shape, const T *begins_buf, const T *sizes_buf)
 
ir::Shape inferSpaceToBatchNDShape (const ir::Shape &input_shape, const ir::Shape &block_shape_shape, const ir::Shape &padding_shape, const int32_t *block_shape_buf, const int32_t *padding_buf)
 
ir::Shape inferSplitShape (const ir::Shape input_shape, int axis_value, int num_splits)
 
ir::Shape inferSqueezeShape (const ir::Shape &in_shape, const ir::operation::Squeeze::Param &param)
 
template<typename T >
StridedSliceParams buildStridedSliceParams (const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
ir::Shape inferStridedSliceShape (const ir::Shape &input_shape, const StridedSliceParams &op_params, uint32_t rank)
 
ir::Shape inferTileShape (const ir::Shape &in_shape, const int32_t *multiplier_buf, const int32_t multiplier_size)
 
ir::Shape inferTransposeShape (const ir::Shape &in_shape, const int32_t *perm_buf, const int32_t rank)
 
ir::Shape inferUnpackShape (const ir::Shape &input_shape, int axis, int rank)
 
std::pair< int, int > calcConvLikeHeightAndWidth (const int in_h, const int in_w, const int ker_h, const int ker_w, const ir::Padding pad, const ir::Stride stride, const ir::Dilation dilation={1, 1})
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int32_t *shape_buf)
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int64_t *shape_buf)
 
template ir::Shape inferRangeShape (int start_val, int limit_val, int delta_val)
 
template ir::Shape inferRangeShape (float start_val, float limit_val, float delta_val)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int32_t *begins_buf, const int32_t *sizes_buf)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int64_t *begins_buf, const int64_t *sizes_buf)
 
template StridedSliceParams buildStridedSliceParams (const uint32_t *begin, const uint32_t *end, const uint32_t *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
int Clamp (const int v, const int lo, const int hi)
 
int StartForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis)
 
int StopForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis, int start_for_axis)
 

Typedef Documentation

◆ Shapes

using onert::shape_inference::Shapes = typedef std::vector<ir::Shape>

Definition at line 37 of file ShapeInference.h.

Function Documentation

◆ buildStridedSliceParams() [1/2]

template<typename T >
StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const T *  begin,
const T *  end,
const T *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

Definition at line 865 of file ShapeInference.cc.

868{
869 StridedSliceParams op_params;
870 op_params.start_indices_count = rank;
871 op_params.stop_indices_count = rank;
872 op_params.strides_count = rank;
873
874 for (int i = 0; i < op_params.strides_count; ++i)
875 {
876 op_params.start_indices[i] = begin[i];
877 op_params.stop_indices[i] = end[i];
878 op_params.strides[i] = strides[i];
879
880 assert(op_params.strides[i] != 0);
881 }
882
883 op_params.begin_mask = begin_mask;
884 op_params.ellipsis_mask = 0; // NYI
885 op_params.end_mask = end_mask;
886 op_params.new_axis_mask = 0; // NYI
887 op_params.shrink_axis_mask = shrink_axis_mask;
888
889 assert(sizeof(op_params.begin_mask) * 4 >= rank);
890
891 return op_params;
892}

References begin, onert::shape_inference::StridedSliceParams::begin_mask, onert::shape_inference::StridedSliceParams::ellipsis_mask, onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::new_axis_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::start_indices, onert::shape_inference::StridedSliceParams::start_indices_count, onert::shape_inference::StridedSliceParams::stop_indices, onert::shape_inference::StridedSliceParams::stop_indices_count, onert::shape_inference::StridedSliceParams::strides, and onert::shape_inference::StridedSliceParams::strides_count.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ buildStridedSliceParams() [2/2]

template StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const uint32_t *  begin,
const uint32_t *  end,
const uint32_t *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

◆ calcConvLikeHeightAndWidth()

std::pair< int, int > onert::shape_inference::calcConvLikeHeightAndWidth ( const int  in_h,
const int  in_w,
const int  ker_h,
const int  ker_w,
const ir::Padding  pad,
const ir::Stride  stride,
const ir::Dilation  dilation = {1, 1} 
)

Definition at line 90 of file ShapeInference.cc.

93 {1, 1})
94{
95 int32_t out_h = 0, out_w = 0;
96 int32_t effective_filter_w_size = (ker_w - 1) * dilation.width_factor + 1;
97 int32_t effective_filter_h_size = (ker_h - 1) * dilation.height_factor + 1;
98 switch (pad.type)
99 {
100 case ir::PaddingType::SAME:
101 out_h = ceil_div(in_h, stride.vertical);
102 out_w = ceil_div(in_w, stride.horizontal);
103 break;
104 case ir::PaddingType::VALID:
105 out_h = ceil_div(in_h - effective_filter_h_size + 1, stride.vertical);
106 out_w = ceil_div(in_w - effective_filter_w_size + 1, stride.horizontal);
107 break;
108 case ir::PaddingType::EXPLICIT:
109 out_h =
110 (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
111 out_w =
112 (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + 1;
113 break;
114 default:
115 assert(false);
116 }
117
118 return {out_h, out_w};
119}

Referenced by inferConv2DShape(), inferDepthwiseConv2DShape(), and inferPoolShape().

◆ Clamp()

int onert::shape_inference::Clamp ( const int  v,
const int  lo,
const int  hi 
)

Definition at line 900 of file ShapeInference.cc.

901{
902 assert(!(hi < lo));
903 if (hi < v)
904 return hi;
905 if (v < lo)
906 return lo;
907 return v;
908}

Referenced by StartForAxis(), and StopForAxis().

◆ inferArgMinMaxShape()

ir::Shape onert::shape_inference::inferArgMinMaxShape ( const ir::Shape input_shape,
int  axis,
int  rank 
)

Definition at line 126 of file ShapeInference.cc.

127{
128 if (axis < 0 || axis >= rank)
129 {
130 throw std::runtime_error("ArgMinMax shape inference: Wrong axis value " + std::to_string(axis));
131 }
132
133 ir::Shape out_shape;
134 for (int idx = 0; idx < rank; ++idx)
135 {
136 if (idx != axis)
137 {
138 int32_t input_dim = input_shape.dim(idx);
139 out_shape.append(input_dim);
140 }
141 }
142
143 return out_shape;
144}

Referenced by onert::exec::DynamicShapeInferer::visit().
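
A minimal usage sketch (illustrative; assumes the onert ir and shape-inference headers are on the include path, and that the code sits inside a function):

using namespace onert;
ir::Shape input{2, 3, 4};                             // rank-3 input
ir::Shape out = shape_inference::inferArgMinMaxShape(input, /*axis=*/1, /*rank=*/3);
// out == {2, 4}: the reduced axis is dropped from the output shape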

◆ inferBatchMatMulShape()

ir::Shape onert::shape_inference::inferBatchMatMulShape ( const ir::Shape lhs_shape,
const ir::Shape rhs_shape,
const ir::operation::BatchMatMul::Param param 
)

Definition at line 228 of file ShapeInference.cc.

230{
231 bool adj_x = param.adj_x;
232 bool adj_y = param.adj_y;
233 ir::Shape output_shape;
234
235 int output_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
236
237 // Extend lhs and rhs shape
238 ir::Shape extended_lhs_shape(lhs_shape);
239 ir::Shape extended_rhs_shape(rhs_shape);
240 extended_lhs_shape.extendRank(output_rank);
241 extended_rhs_shape.extendRank(output_rank);
242
243 for (int i = 0; i < output_rank - 2; i++)
244 {
245 const int lhs_dim = extended_lhs_shape.dim(i);
246 const int rhs_dim = extended_rhs_shape.dim(i);
247 int broadcast_dim = lhs_dim;
248 if (lhs_dim != rhs_dim)
249 {
250 if (lhs_dim == 1)
251 {
252 broadcast_dim = rhs_dim;
253 }
254 else if (rhs_dim != 1)
255 {
256 throw std::runtime_error{"BatchMatMul shape inference: invalid brodcasting input shape"};
257 }
258 }
259
260 output_shape.append(broadcast_dim);
261 }
262
263 // Fill in the matmul dimensions.
264 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
265 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
266
267 output_shape.append(extended_lhs_shape.dim(lhs_rows_index));
268 output_shape.append(extended_rhs_shape.dim(rhs_cols_index));
269
270 return output_shape;
271}

References onert::ir::operation::BatchMatMul::Param::adj_x, onert::ir::operation::BatchMatMul::Param::adj_y, and output_shape.

Referenced by onert::exec::DynamicShapeInferer::visit().
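
A usage sketch, assuming BatchMatMul::Param is a plain struct whose adj_x/adj_y flags can be set directly:

using namespace onert;
ir::operation::BatchMatMul::Param param;
param.adj_x = false;                                  // do not transpose lhs
param.adj_y = false;                                  // do not transpose rhs
ir::Shape lhs{2, 3, 4}, rhs{2, 4, 5};
ir::Shape out = shape_inference::inferBatchMatMulShape(lhs, rhs, param);
// out == {2, 3, 5}: leading batch dims are broadcast, then lhs rows x rhs cols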

◆ inferBCQFullyConnectedShape()

ir::Shape onert::shape_inference::inferBCQFullyConnectedShape ( const ir::Shape in_shape,
const ir::Shape cluster_shape,
const int32_t *  cluster_buf 
)

Definition at line 407 of file ShapeInference.cc.

409{
410 assert(cluster_shape.rank() == 2);
411 assert(cluster_shape.dim(1) == 2);
412
413 const auto input_size = in_shape.dim(1);
414 const auto output_size = bcq::getOutputSize(cluster_shape, cluster_buf);
415
416 return {ir::Shape({output_size, input_size})};
417}

References onert::shape_inference::bcq::getOutputSize().

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBCQGatherShape()

ir::Shape onert::shape_inference::inferBCQGatherShape ( const ir::Shape indices_shape,
const ir::Shape cluster_shape,
const int32_t *  cluster_buf,
int  rank,
const ir::operation::BCQGather::Param param 
)

Definition at line 419 of file ShapeInference.cc.

422{
423 ir::Shape out_shape;
424 ir::Shape in_original_shape;
425
426 assert(cluster_shape.rank() == 2);
427 assert(cluster_shape.dim(1) == 2);
428
429 auto hidden_size = param.input_hidden_size;
430 auto axis = param.axis;
431
432 in_original_shape.append(bcq::getOutputSize(cluster_shape, cluster_buf));
433 in_original_shape.append(hidden_size);
434
435 const int indices_rank = indices_shape.rank();
436 for (int idx = 0; idx < rank; ++idx)
437 {
438 if (idx == (int)axis)
439 {
440 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
441 {
442 out_shape.append(indices_shape.dim(indices_idx));
443 }
444 }
445 else
446 {
447 out_shape.append(in_original_shape.dim(idx));
448 }
449 }
450
451 return out_shape;
452}

References onert::ir::operation::BCQGather::Param::axis, onert::shape_inference::bcq::getOutputSize(), and onert::ir::operation::BCQGather::Param::input_hidden_size.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBroadcastToShape()

ir::Shape onert::shape_inference::inferBroadcastToShape ( const ir::Shape  shp_shape,
const int32_t *  shp_buf 
)

Definition at line 277 of file ShapeInference.cc.

278{
279
280 const int num_elements = shp_shape.num_elements();
281
282 assert(num_elements != 0);
283 assert(shp_buf);
284
285 ir::Shape new_shape(num_elements);
286
287 for (int i = 0; i < num_elements; ++i)
288 {
289 assert(shp_buf[i] != 0); // It shouldn't be 0.
290 new_shape.dim(i) = shp_buf[i];
291 }
292
293 return new_shape;
294}

Referenced by onert::exec::DynamicShapeInferer::visit().
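
A usage sketch; shp_shape is the shape of the rank-1 shape tensor and shp_buf holds its values:

using namespace onert;
const int32_t target[] = {2, 4, 5};                   // requested output shape
ir::Shape shp_shape(1);
shp_shape.dim(0) = 3;                                 // the shape tensor holds 3 values
ir::Shape out = shape_inference::inferBroadcastToShape(shp_shape, target);
// out == {2, 4, 5}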

◆ inferConcatShape()

ir::Shape onert::shape_inference::inferConcatShape ( const Shapes in_shapes,
const ir::operation::Concat::Param param 
)

Definition at line 296 of file ShapeInference.cc.

297{
298 const int32_t concat_axis = param.axis >= 0 ? param.axis : in_shapes[0].rank() + param.axis;
299 const auto &first_in_shape = in_shapes[0];
300
301 // Check that all shapes are equal except for concat axis dimension
302 for (const auto &in_shape : in_shapes)
303 {
304 if (in_shape.rank() != first_in_shape.rank())
305 throw std::runtime_error("Rank in all input tensors should be same");
306
307 for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
308 if (!(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx)))
309 throw std::runtime_error("All tensor should have same dimension "
310 "except dimension on passed axis");
311 }
312
313 // Calculate output shape
314 ir::Shape out_shape(first_in_shape);
315 out_shape.dim(concat_axis) = 0;
316 for (const auto &in_shape : in_shapes)
317 out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
318 return out_shape;
319}

References onert::ir::operation::Concat::Param::axis.

Referenced by onert::exec::DynamicShapeInferer::visit().
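
A usage sketch, assuming Concat::Param is a plain struct exposing its axis field:

using namespace onert;
shape_inference::Shapes inputs{ir::Shape{2, 3}, ir::Shape{2, 5}};
ir::operation::Concat::Param param;
param.axis = 1;                                       // concatenate along the second axis
ir::Shape out = shape_inference::inferConcatShape(inputs, param);
// out == {2, 8}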

◆ inferConv2DShape()

ir::Shape onert::shape_inference::inferConv2DShape ( const ir::Shape in_shape,
const ir::Shape ker_shape,
const ir::operation::Conv2D::Param param 
)

Definition at line 321 of file ShapeInference.cc.

323{
324 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
325 throw std::runtime_error{"Conv2D: stride values must be positive"};
326
327 auto ifm_shape = in_shape.asFeature();
328
329 // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
330 auto kf_shape = ker_shape.asFeature();
331 assert(ifm_shape.C == kf_shape.C);
332
333 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
334 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
335
336 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.N};
337}

References calcConvLikeHeightAndWidth(), onert::ir::operation::Conv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::Conv2D::Param::padding, onert::ir::operation::Conv2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferDepthwiseConv2DShape()

ir::Shape onert::shape_inference::inferDepthwiseConv2DShape ( const ir::Shape in_shape,
const ir::Shape ker_shape,
const ir::operation::DepthwiseConv2D::Param param 
)

Definition at line 339 of file ShapeInference.cc.

341{
342 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
343 throw std::runtime_error{"DepthwiseConv2D: stride values must be positive"};
344
345 auto ifm_shape = in_shape.asFeature();
346
347 // Kernel format is [1, kernel_height, kernel_width, depth_out]
348 auto kf_shape = ker_shape.asFeature();
349 assert(kf_shape.C == static_cast<int32_t>(ifm_shape.C * param.multiplier));
350 assert(kf_shape.N == 1);
351
352 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
353 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
354
355 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.C};
356}

References calcConvLikeHeightAndWidth(), onert::ir::operation::DepthwiseConv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::DepthwiseConv2D::Param::multiplier, onert::ir::operation::DepthwiseConv2D::Param::padding, onert::ir::operation::DepthwiseConv2D::Param::stride, and onert::ir::Stride::vertical.

◆ inferEltwiseShape()

ir::Shape onert::shape_inference::inferEltwiseShape ( const ir::Shape lhs_shape,
const ir::Shape rhs_shape 
)

Definition at line 121 of file ShapeInference.cc.

122{
123 return broadcastShapes(lhs_shape, rhs_shape);
124}
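
A usage sketch; the result comes from broadcastShapes(), assumed here to follow the usual NumPy-style broadcasting rules:

using namespace onert;
ir::Shape lhs{2, 3, 5};
ir::Shape rhs{3, 1};
ir::Shape out = shape_inference::inferEltwiseShape(lhs, rhs);
// out == {2, 3, 5} under standard broadcasting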

◆ inferExpandDimsShape()

ir::Shape onert::shape_inference::inferExpandDimsShape ( const ir::Shape in_shape,
int32_t  axis 
)

Definition at line 358 of file ShapeInference.cc.

359{
360 ir::Shape out_shape(in_shape.rank() + 1);
361
362 axis = ((axis >= 0) ? axis : /* when axis < 0 */ (out_shape.rank() + axis));
363 if (!(0 <= axis && axis <= in_shape.rank()))
364 throw std::runtime_error("axis of dim is out of range");
365
366 for (int x = 0, out_x = 0; out_x < out_shape.rank(); ++out_x)
367 {
368 if (out_x == axis)
369 out_shape.dim(out_x) = 1;
370 else
371 out_shape.dim(out_x) = in_shape.dim(x++);
372 }
373
374 return out_shape;
375}

Referenced by onert::exec::DynamicShapeInferer::visit().
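
A usage sketch:

using namespace onert;
ir::Shape input{2, 3};
ir::Shape out = shape_inference::inferExpandDimsShape(input, /*axis=*/0);
// out == {1, 2, 3}; a negative axis counts from the end (axis = -1 appends the new dim)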

◆ inferFillShape() [1/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const int32_t *  shape_buf 
)

◆ inferFillShape() [2/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const int64_t *  shape_buf 
)

◆ inferFillShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const T *  shape_buf 
)

Definition at line 377 of file ShapeInference.cc.

378{
379 ir::Shape out_shape(fill_shape.dim(0));
380
381 for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
382 {
383 out_shape.dim(out_x) = static_cast<int32_t>(shape_buf[out_x]);
384 }
385
386 return out_shape;
387}
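
A usage sketch; fill_shape is the shape of the rank-1 shape tensor and shape_buf its contents (explicit instantiations for int32_t and int64_t are listed above):

using namespace onert;
const int32_t dims[] = {4, 2, 7};
ir::Shape fill_shape(1);
fill_shape.dim(0) = 3;                                // the shape tensor holds 3 values
ir::Shape out = shape_inference::inferFillShape(fill_shape, dims);
// out == {4, 2, 7}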

◆ inferFullyConnectedShape()

ir::Shape onert::shape_inference::inferFullyConnectedShape ( const ir::Shape in_shape,
const ir::Shape ker_shape 
)

Definition at line 393 of file ShapeInference.cc.

394{
395 assert(in_shape.rank() >= 2);
396 assert(ker_shape.rank() == 2);
397
398 const auto input_size_with_batch = in_shape.num_elements();
399 const auto num_units = ker_shape.dim(0);
400 const auto input_size = ker_shape.dim(1);
401 const auto batch_size = input_size_with_batch / input_size;
402 assert(input_size_with_batch % input_size == 0);
403
404 return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
405}

Referenced by onert::exec::DynamicShapeInferer::visit().
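
A usage sketch; the output batch size is the input element count divided by the weight's input size:

using namespace onert;
ir::Shape input{2, 8};                                // 16 elements
ir::Shape weights{16, 8};                             // [num_units, input_size]
ir::Shape out = shape_inference::inferFullyConnectedShape(input, weights);
// out == {2, 16}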

◆ inferGatherShape()

ir::Shape onert::shape_inference::inferGatherShape ( const ir::Shape input_shape,
const ir::Shape indices_shape,
int  axis,
int  rank 
)

Definition at line 454 of file ShapeInference.cc.

456{
457 ir::Shape out_shape;
458
459 const int indices_rank = indices_shape.rank();
460
461 for (int idx = 0; idx < rank; ++idx)
462 {
463 if (idx == axis)
464 {
465 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
466 {
467 out_shape.append(indices_shape.dim(indices_idx));
468 }
469 }
470 else
471 {
472 out_shape.append(input_shape.dim(idx));
473 }
474 }
475
476 return out_shape;
477}

Referenced by onert::exec::DynamicShapeInferer::visit().
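
A usage sketch; the indices shape is spliced into the input shape at the gather axis:

using namespace onert;
ir::Shape input{5, 7};
ir::Shape indices{2, 3};
ir::Shape out = shape_inference::inferGatherShape(input, indices, /*axis=*/0, /*rank=*/2);
// out == {2, 3, 7}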

◆ inferOnehotShape()

ir::Shape onert::shape_inference::inferOnehotShape ( const ir::Shape input_shape,
const int  depth,
int  axis 
)

Definition at line 479 of file ShapeInference.cc.

480{
481 assert(depth >= 0);
482 const auto rank = input_shape.rank() + 1;
483 ir::Shape newShape(rank);
484
485 axis = (axis == -1) ? (rank - 1) : axis;
486
487 for (int i = 0; i < rank; ++i)
488 {
489 if (i < axis)
490 {
491 newShape.dim(i) = input_shape.dim(i);
492 }
493 else if (i == axis)
494 {
495 newShape.dim(i) = depth;
496 }
497 else
498 {
499 newShape.dim(i) = input_shape.dim(i - 1);
500 }
501 }
502
503 return newShape;
504}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPackShape()

ir::Shape onert::shape_inference::inferPackShape ( const ir::Shape input_shape,
int  axis,
int  rank,
int  num 
)

Definition at line 506 of file ShapeInference.cc.

507{
508 ir::Shape out_shape;
509 int in_idx = 0;
510
511 for (int out_idx = 0; out_idx < rank; ++out_idx)
512 {
513 if (out_idx == axis)
514 {
515 out_shape.append(num);
516 }
517 else
518 {
519 out_shape.append(input_shape.dim(in_idx++));
520 }
521 }
522
523 return out_shape;
524}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPadShape()

ir::Shape onert::shape_inference::inferPadShape ( const ir::Shape in_shape,
const int32_t *  pad_buf,
const size_t  num_pads 
)

Definition at line 526 of file ShapeInference.cc.

527{
528 assert(num_pads % 2 == 0);
529 const int32_t rank = num_pads / 2;
530
531 ir::Shape ret(rank);
532 for (int32_t i = 0; i < rank; ++i)
533 {
534 const auto before_padding = pad_buf[i * 2];
535 const auto after_padding = pad_buf[i * 2 + 1];
536
537 ret.dim(i) = in_shape.dim(i) + before_padding + after_padding;
538 }
539
540 return ret;
541}

Referenced by onert::exec::DynamicShapeInferer::visit().
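
A usage sketch; pad_buf holds (before, after) pairs per dimension, so num_pads is twice the input rank:

using namespace onert;
const int32_t pads[] = {1, 1, 2, 2};                  // dim 0: +1/+1, dim 1: +2/+2
ir::Shape input{4, 5};
ir::Shape out = shape_inference::inferPadShape(input, pads, /*num_pads=*/4);
// out == {6, 9}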

◆ inferPoolShape()

ir::Shape onert::shape_inference::inferPoolShape ( const ir::Shape in_shape,
const ir::operation::Pool2D::Param param 
)

Definition at line 543 of file ShapeInference.cc.

544{
545 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
546 throw std::runtime_error{"Pool2D: stride values must be positive"};
547
548 auto ifm_shape = in_shape.asFeature();
549 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh,
550 param.kw, param.padding, param.stride);
551 // Pooling don't change number of channels and batch size
552 return ir::Shape{ifm_shape.N, out_h, out_w, ifm_shape.C};
553}

References calcConvLikeHeightAndWidth(), onert::ir::Stride::horizontal, onert::ir::operation::Pool2D::Param::kh, onert::ir::operation::Pool2D::Param::kw, onert::ir::operation::Pool2D::Param::padding, onert::ir::operation::Pool2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferRangeShape() [1/3]

template ir::Shape onert::shape_inference::inferRangeShape ( float  start_val,
float  limit_val,
float  delta_val 
)

◆ inferRangeShape() [2/3]

template ir::Shape onert::shape_inference::inferRangeShape ( int  start_val,
int  limit_val,
int  delta_val 
)

◆ inferRangeShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferRangeShape ( T  start_val,
T  limit_val,
T  delta_val 
)

Definition at line 580 of file ShapeInference.cc.

581{
582 ir::Shape out_shape(static_cast<int>(1));
583
584 out_shape.dim(0) =
585 (std::is_integral<T>::value
586 ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
587 : std::ceil(std::abs((start_val - limit_val) / delta_val)));
588 return out_shape;
589}
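
A usage sketch using the int instantiation listed above:

using namespace onert;
ir::Shape out = shape_inference::inferRangeShape(/*start_val=*/0, /*limit_val=*/10, /*delta_val=*/3);
// out == {4}: the range would hold 0, 3, 6, 9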

◆ inferReduceShape()

ir::Shape onert::shape_inference::inferReduceShape ( const ir::Shape input_shape,
const std::vector< int > &  axes,
bool  keep_dims 
)

Definition at line 146 of file ShapeInference.cc.

148{
149 int num_axis = axes.size();
150 int input_num_dims = input_shape.rank();
151 if (input_num_dims == 0)
152 {
153 ir::Shape out_shape(0);
154 return out_shape;
155 }
156 if (keep_dims)
157 {
158 ir::Shape out_shape;
159 for (int idx = 0; idx < input_num_dims; ++idx)
160 {
161 bool is_axis = false;
162 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
163 {
164 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
165 {
166 is_axis = true;
167 break;
168 }
169 }
170 if (is_axis)
171 {
172 out_shape.append(1);
173 }
174 else
175 {
176 out_shape.append(input_shape.dim(idx));
177 }
178 }
179 return out_shape;
180 }
181 else
182 {
183 // Calculates size of reducing axis.
184 for (int i = 0; i < num_axis; ++i)
185 {
186 int current = axes[i];
187 if (!(-input_num_dims <= current && current < input_num_dims))
188 throw std::runtime_error{"Invalid dim value " + std::to_string(current)};
189 if (current < 0)
190 {
191 current += input_num_dims;
192 }
193 for (int j = 0; j < i; ++j)
194 {
195 int previous = axes[j];
196 if (previous < 0)
197 {
198 previous += input_num_dims;
199 }
200 if (current == previous)
201 {
202 break;
203 }
204 }
205 }
206 // Determines output dimensions.
207 ir::Shape out_shape;
208 for (int idx = 0; idx < input_num_dims; ++idx)
209 {
210 bool is_axis = false;
211 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
212 {
213 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
214 {
215 is_axis = true;
216 break;
217 }
218 }
219 if (!is_axis)
220 {
221 out_shape.append(input_shape.dim(idx));
222 }
223 }
224 return out_shape;
225 }
226}

Referenced by onert::exec::DynamicShapeInferer::visit().
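
A usage sketch; reduced axes are dropped, or kept as size-1 dimensions when keep_dims is true:

using namespace onert;
ir::Shape input{2, 3, 4};
const std::vector<int> axes{1};
ir::Shape dropped = shape_inference::inferReduceShape(input, axes, /*keep_dims=*/false); // {2, 4}
ir::Shape kept = shape_inference::inferReduceShape(input, axes, /*keep_dims=*/true);     // {2, 1, 4}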

◆ inferReshapeShape()

ir::Shape onert::shape_inference::inferReshapeShape ( const ir::Shape input_shape,
const int32_t *  shape_buf,
const int32_t  shape_num_elements 
)

Definition at line 595 of file ShapeInference.cc.

597{
598 ir::Shape ret(shape_num_elements);
599 int32_t flatten_dim = ir::Shape::kUnspecifiedDim;
600 auto total_num_elements = input_shape.num_elements();
601 for (int32_t i = 0; i < shape_num_elements; ++i)
602 {
603 if (shape_buf[i] < 0)
604 {
605 if (flatten_dim != ir::Shape::kUnspecifiedDim)
606 throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
607 flatten_dim = i;
608 ret.dim(i) = 1;
609 }
610 else
611 {
612 ret.dim(i) = shape_buf[i];
613 }
614 }
615 if (flatten_dim != ir::Shape::kUnspecifiedDim)
616 ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
617
618 // Check reshapable
619 if (total_num_elements != static_cast<size_t>(ret.num_elements()))
620 {
621 // Multi batch case
622 // TODO Handle multi batch case more precisely on runtime level
623 if ((ret.dim(0) == 1) &&
624 (total_num_elements == static_cast<size_t>(ret.num_elements() * input_shape.dim(0))))
625 ret.dim(0) = input_shape.dim(0);
626 else
627 throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
628 }
629
630 return ret;
631}

Referenced by onert::exec::DynamicShapeInferer::visit().
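
A usage sketch; a single negative entry in shape_buf marks the dimension to be inferred from the remaining element count:

using namespace onert;
const int32_t new_dims[] = {4, -1};
ir::Shape input{2, 3, 4};                             // 24 elements
ir::Shape out = shape_inference::inferReshapeShape(input, new_dims, /*shape_num_elements=*/2);
// out == {4, 6}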

◆ inferResizeBilinearShape()

ir::Shape onert::shape_inference::inferResizeBilinearShape ( const ir::Shape in_shape,
const int32_t  output_height,
const int32_t  output_width 
)

Definition at line 555 of file ShapeInference.cc.

557{
558 assert(in_shape.rank() == 4);
559 if (output_height < 0)
560 {
561 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_height = " +
562 std::to_string(output_height)};
563 }
564 if (output_width < 0)
565 {
566 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_width = " +
567 std::to_string(output_width)};
568 }
569
570 ir::Shape ret(in_shape.rank());
571
572 ret.dim(0) = in_shape.dim(0);
573 ret.dim(1) = output_height;
574 ret.dim(2) = output_width;
575 ret.dim(3) = in_shape.dim(3);
576
577 return ret;
578}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSelectShape()

ir::Shape onert::shape_inference::inferSelectShape ( const ir::Shape input_cond_shape,
const ir::Shape input_true_shape,
const ir::Shape input_false_shape 
)

Definition at line 633 of file ShapeInference.cc.

635{
636 auto haveSameShapes = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
637 const ir::Shape &input_false_shape) {
638 if ((input_cond_shape.rank() != input_true_shape.rank()) ||
639 input_cond_shape.rank() != input_false_shape.rank())
640 {
641 return false;
642 }
643
644 int rank = input_cond_shape.rank();
645 for (int i = 0; i < rank; ++i)
646 {
647 if (input_cond_shape.dim(i) != input_true_shape.dim(i) ||
648 input_cond_shape.dim(i) != input_false_shape.dim(i))
649 {
650 return false;
651 }
652 }
653
654 return true;
655 };
656
657 auto calculateShape = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
658 const ir::Shape &input_false_shape, ir::Shape &new_shape) {
659 ir::Shape cond_shape = input_cond_shape;
660 ir::Shape true_shape = input_true_shape;
661 ir::Shape false_shape = input_false_shape;
662 int most_rank =
663 (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
664 ? cond_shape.rank()
665 : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
666
667 ir::Shape calculate_shape(most_rank);
668
669 cond_shape.extendRank(most_rank);
670 true_shape.extendRank(most_rank);
671 false_shape.extendRank(most_rank);
672
673 for (int i = 0; i < most_rank; ++i)
674 {
675 calculate_shape.dim(i) =
676 (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
677 ? cond_shape.dim(i)
678 : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
679
680 if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
681 (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
682 (false_shape.dim(i) != calculate_shape.dim(i) && false_shape.dim(i) != 1))
683 {
684 return false;
685 }
686 }
687
688 new_shape = calculate_shape;
689
690 return true;
691 };
692
693 bool havesame = haveSameShapes(input_cond_shape, input_true_shape, input_false_shape);
694 if (havesame)
695 {
696 return input_cond_shape;
697 }
698
699 ir::Shape new_shape;
700 bool possible = calculateShape(input_cond_shape, input_true_shape, input_false_shape, new_shape);
701
702 if (!possible)
703 {
704 throw std::runtime_error("Broadcasting is not possible.");
705 }
706
707 return new_shape;
708}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSliceShape() [1/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const int32_t *  begins_buf,
const int32_t *  sizes_buf 
)

◆ inferSliceShape() [2/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const int64_t *  begins_buf,
const int64_t *  sizes_buf 
)

◆ inferSliceShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const T *  begins_buf,
const T *  sizes_buf 
)

Definition at line 711 of file ShapeInference.cc.

712{
713 const uint32_t rank = input_shape.rank();
714 ir::Shape out_shape(rank);
715
716 for (uint32_t idx = 0; idx < rank; ++idx)
717 {
718 const auto input_dim = input_shape.dim(idx);
719
720 // begin is zero-based
721 auto begin = begins_buf[idx];
722 if (begin < 0)
723 throw std::runtime_error("shape inference Slice: Invalid begin.");
724
725 // size is one-based
726 auto size = sizes_buf[idx];
727 if (size < -1)
728 throw std::runtime_error("shape inference Slice: Invalid size.");
729
730 if (size == -1)
731 {
732 size = input_dim - begin;
733 }
734 else
735 {
736 if (input_dim < static_cast<int32_t>(begin + size))
737 throw std::runtime_error("shape inference Slice: Invalid begin and size.");
738 }
739 out_shape.dim(idx) = static_cast<int32_t>(size);
740 }
741
742 return out_shape;
743}

References begin, and size.

Referenced by onert::exec::DynamicShapeInferer::visit().
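
A usage sketch; a size of -1 selects everything from begin to the end of that dimension (instantiations exist for int32_t and int64_t buffers):

using namespace onert;
const int32_t begins[] = {1, 2};
const int32_t sizes[] = {2, -1};
ir::Shape input{4, 10};
ir::Shape out = shape_inference::inferSliceShape(input, begins, sizes);
// out == {2, 8}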

◆ inferSpaceToBatchNDShape()

ir::Shape onert::shape_inference::inferSpaceToBatchNDShape ( const ir::Shape input_shape,
const ir::Shape block_shape_shape,
const ir::Shape padding_shape,
const int32_t *  block_shape_buf,
const int32_t *  padding_buf 
)

Definition at line 750 of file ShapeInference.cc.

754{
755 const uint32_t rank = input_shape.rank();
756 ir::Shape out_shape(rank);
757
758 // Currently, only 4D NHWC input/output op_context are supported.
759 // The 4D array need to have exactly 2 spatial dimensions.
760 // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
761 [[maybe_unused]] const int32_t kInputDimensionNum = 4;
762 [[maybe_unused]] const int32_t kBlockSizeDimensionNum = 1;
763 const int32_t kSpatialDimensionNum = 2;
764
765 assert(block_shape_shape.rank() == kBlockSizeDimensionNum);
766 assert(block_shape_shape.dim(0) == kSpatialDimensionNum);
767 assert(padding_shape.dim(0) == kSpatialDimensionNum);
768 assert(padding_shape.dim(1) == 2); // fixed, meaning left/right padding for each element
769 assert(padding_shape.rank() == 2); // fixed, meaning dimension(dim 0) and padding length(dim 1)
770
771 // Ensures the input height and width (with padding) is a multiple of block
772 // shape height and width.
773 for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
774 {
775 int final_dim_size =
776 (input_shape.dim(dim + 1) + padding_buf[dim * 2] + padding_buf[dim * 2 + 1]);
777
778 assert(final_dim_size % block_shape_buf[dim] == 0);
779
780 out_shape.dim(dim + 1) = final_dim_size / block_shape_buf[dim];
781 }
782
783 const int output_batch_size = input_shape.dim(0) * block_shape_buf[0] * block_shape_buf[1];
784 const int output_channel_size = input_shape.dim(3);
785
786 out_shape.dim(0) = output_batch_size;
787 out_shape.dim(3) = output_channel_size;
788
789 return out_shape;
790}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSplitShape()

ir::Shape onert::shape_inference::inferSplitShape ( const ir::Shape  input_shape,
int  axis_value,
int  num_splits 
)

Definition at line 792 of file ShapeInference.cc.

793{
794 ir::Shape newShape(input_shape);
795
796 assert(axis_value >= 0);
797 assert(axis_value < input_shape.rank());
798
799 const int input_size = input_shape.dim(axis_value);
800 assert(input_size % num_splits == 0);
801 const int slice_size = input_size / num_splits;
802
803 newShape.dim(axis_value) = slice_size;
804
805 return newShape;
806}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSqueezeShape()

ir::Shape onert::shape_inference::inferSqueezeShape ( const ir::Shape in_shape,
const ir::operation::Squeeze::Param param 
)

Definition at line 808 of file ShapeInference.cc.

809{
810 const int ndims = param.ndim;
811 const int *squeeze_dims = param.dims;
812 bool should_squeeze[8] = {false};
813 int num_squeezed_dims = 0;
814 int shape_rank = in_shape.rank();
815 if (ndims == 0)
816 {
817 for (int idx = 0; idx < shape_rank; ++idx)
818 {
819 if (in_shape.dim(idx) == 1)
820 {
821 should_squeeze[idx] = true;
822 ++num_squeezed_dims;
823 }
824 }
825 }
826 else
827 {
828 for (int idx = 0; idx < ndims; ++idx)
829 {
830 int current = squeeze_dims[idx];
831 if (current < 0)
832 {
833 current += shape_rank;
834 }
835
836 if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
837 {
838 throw std::runtime_error(
839 "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
840 }
841
842 if (!should_squeeze[current])
843 {
844 ++num_squeezed_dims;
845 }
846 should_squeeze[current] = true;
847 }
848 }
849
850 // Set output shape.
851 ir::Shape out_shape(shape_rank - num_squeezed_dims);
852 for (int in_idx = 0, out_idx = 0; in_idx < shape_rank; ++in_idx)
853 {
854 if (!should_squeeze[in_idx])
855 {
856 out_shape.dim(out_idx++) = in_shape.dim(in_idx);
857 }
858 }
859
860 return out_shape;
861}

References onert::ir::operation::Squeeze::Param::dims, and onert::ir::operation::Squeeze::Param::ndim.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferStridedSliceShape()

ir::Shape onert::shape_inference::inferStridedSliceShape ( const ir::Shape input_shape,
const StridedSliceParams op_params,
uint32_t  rank 
)

Definition at line 1015 of file ShapeInference.cc.

1017{
1018 ir::Shape out_shape;
1019
1020 for (uint32_t idx = 0; idx < rank; ++idx)
1021 {
1022 int32_t stride = op_params.strides[idx];
1023 int32_t begin = StartForAxis(op_params, input_shape, idx);
1024 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
1025
1026 // When shrinking an axis, the end position does not matter (and can be
1027 // incorrect when negative indexing is used, see Issue #19260). Always use
1028 // begin + 1 to generate a length 1 slice, since begin has
1029 // already been adjusted for negative indices by StartForAxis.
1030 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
1031 if (shrink_axis)
1032 {
1033 end = begin + 1;
1034 }
1035
1036 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
1037 dim_shape = dim_shape < 0 ? 0 : dim_shape;
1038 if (!shrink_axis)
1039 {
1040 out_shape.append(dim_shape);
1041 }
1042 }
1043
1044 return out_shape;
1045}

References begin, onert::shape_inference::StridedSliceParams::shrink_axis_mask, StartForAxis(), StopForAxis(), and onert::shape_inference::StridedSliceParams::strides.

Referenced by onert::exec::DynamicShapeInferer::visit().
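
A usage sketch combining the uint32_t instantiation of buildStridedSliceParams() (documented above) with this function:

using namespace onert;
const uint32_t begin[] = {1, 0}, end[] = {3, 10}, strides[] = {1, 2};
auto params = shape_inference::buildStridedSliceParams(begin, end, strides,
                                                       /*begin_mask=*/0, /*end_mask=*/0,
                                                       /*shrink_axis_mask=*/0, /*rank=*/2);
ir::Shape input{4, 10};
ir::Shape out = shape_inference::inferStridedSliceShape(input, params, /*rank=*/2);
// out == {2, 5}: ceil((3-1)/1) along axis 0, ceil((10-0)/2) along axis 1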

◆ inferTileShape()

ir::Shape onert::shape_inference::inferTileShape ( const ir::Shape in_shape,
const int32_t *  multiplier_buf,
const int32_t  multiplier_size 
)

Definition at line 1047 of file ShapeInference.cc.

1049{
1050 if (multiplier_size != in_shape.rank())
1051 {
1052 throw std::runtime_error(
1053 "inferTileShape failed, input rank: " + std::to_string(in_shape.rank()) +
1054 ", bad multipliers size: " + std::to_string(multiplier_size) + "");
1055 }
1056 ir::Shape new_Shape(in_shape.rank());
1057
1058 for (int i = 0; i < in_shape.rank(); ++i)
1059 {
1060 assert(multiplier_buf[i]); // multiplier_buf[i] shuld not be 0.
1061 new_Shape.dim(i) = in_shape.dim(i) * multiplier_buf[i];
1062 }
1063 return new_Shape;
1064}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferTransposeShape()

ir::Shape onert::shape_inference::inferTransposeShape ( const ir::Shape in_shape,
const int32_t *  perm_buf,
const int32_t  rank 
)

Definition at line 1066 of file ShapeInference.cc.

1068{
1069 const auto rank = in_shape.rank();
1070 if (perm_size > rank)
1071 {
1072 throw std::runtime_error("inferTransposeShape failed, bad permutation size: " +
1073 std::to_string(perm_size));
1074 }
1075
1076 const int32_t *perm_data = perm_buf;
1077 std::vector<int32_t> regular_perm_vec;
1078 if (perm_size == 0)
1079 {
1080 // perm_data will be set to (n-1...0)
1081 regular_perm_vec.resize(rank);
1082 std::iota(regular_perm_vec.begin(), regular_perm_vec.end(), 0);
1083 std::reverse(regular_perm_vec.begin(), regular_perm_vec.end());
1084 perm_data = regular_perm_vec.data();
1085 }
1086 else
1087 {
1088 assert(rank == perm_size);
1089 }
1090
1091 ir::Shape out_shape(rank);
1092 std::vector<bool> visit_perms(rank, false);
1093 for (int idx = 0; idx < rank; idx++)
1094 {
1095 const auto perm_val = perm_data[idx];
1096 // Check invalid permutation value
1097 if (perm_val < 0 || perm_val >= rank)
1098 {
1099 throw std::runtime_error("inferTransposeShape failed, bad permutation value: " +
1100 std::to_string(perm_val));
1101 }
1102
1103 // Check duplicated permutation value
1104 if (visit_perms.at(perm_val))
1105 {
1106 throw std::runtime_error("inferTransposeShape failed, duplicated permutation value: " +
1107 std::to_string(perm_val));
1108 }
1109 visit_perms.at(perm_val) = true;
1110
1111 out_shape.dim(idx) = in_shape.dim(perm_val);
1112 }
1113 return out_shape;
1114}

Referenced by onert::exec::DynamicShapeInferer::visit().
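
A usage sketch; the third argument is the number of permutation entries (it must match the input rank), and passing 0 selects the default reversed permutation instead:

using namespace onert;
const int32_t perm[] = {2, 0, 1};
ir::Shape input{2, 3, 4};
ir::Shape out = shape_inference::inferTransposeShape(input, perm, 3);
// out == {4, 2, 3}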

◆ inferUnpackShape()

ir::Shape onert::shape_inference::inferUnpackShape ( const ir::Shape input_shape,
int  axis,
int  rank 
)

Definition at line 1116 of file ShapeInference.cc.

1117{
1118 ir::Shape out_shape;
1119
1120 for (int out_idx = 0; out_idx < rank; out_idx++)
1121 {
1122 if (out_idx != axis)
1123 {
1124 out_shape.append(input_shape.dim(out_idx));
1125 }
1126 }
1127
1128 return out_shape;
1129}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ StartForAxis()

int onert::shape_inference::StartForAxis ( const StridedSliceParams params,
const ir::Shape input_shape,
int  axis 
)

Definition at line 910 of file ShapeInference.cc.

911{
912 const auto begin_mask = params.begin_mask;
913 const auto *start_indices = params.start_indices;
914 const auto *strides = params.strides;
915 // Begin with the specified index.
916 int start = start_indices[axis];
917
918 // begin_mask override
919 if (begin_mask & 1 << axis)
920 {
921 if (strides[axis] > 0)
922 {
923 // Forward iteration - use the first element. These values will get
924 // clamped below (Note: We could have set them to 0 and axis_size-1, but
925 // use lowest() and max() to maintain symmetry with StopForAxis())
926 start = std::numeric_limits<int>::lowest();
927 }
928 else
929 {
930 // Backward iteration - use the last element.
931 start = std::numeric_limits<int>::max();
932 }
933 }
934
935 // Handle negative indices
936 int axis_size = input_shape.dim(axis);
937 if (start < 0)
938 {
939 start += axis_size;
940 }
941
942 // Clamping
943 start = Clamp(start, 0, axis_size - 1);
944
945 return start;
946}

References onert::shape_inference::StridedSliceParams::begin_mask, Clamp(), onert::shape_inference::StridedSliceParams::start_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().

◆ StopForAxis()

int onert::shape_inference::StopForAxis ( const StridedSliceParams params,
const ir::Shape input_shape,
int  axis,
int  start_for_axis 
)

Definition at line 953 of file ShapeInference.cc.

955{
956 const auto end_mask = params.end_mask;
957 const auto shrink_axis_mask = params.shrink_axis_mask;
958 const auto *stop_indices = params.stop_indices;
959 const auto *strides = params.strides;
960
961 // Begin with the specified index
962 const bool shrink_axis = shrink_axis_mask & (1 << axis);
963 int stop = stop_indices[axis];
964
965 // When shrinking an axis, the end position does not matter (and can be
966 // incorrect when negative indexing is used, see Issue #19260). Always use
967 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
968 // already been adjusted for negative indices.
969 if (shrink_axis)
970 {
971 stop = start_for_axis + 1;
972 }
973
974 // end_mask override
975 if (end_mask & (1 << axis))
976 {
977 if (strides[axis] > 0)
978 {
979 // Forward iteration - use the last element. These values will get
980 // clamped below
981 stop = std::numeric_limits<int>::max();
982 }
983 else
984 {
985 // Backward iteration - use the first element.
986 stop = std::numeric_limits<int>::lowest();
987 }
988 }
989
990 // Handle negative indices
991
992 const int axis_size = input_shape.dim(axis);
993 if (stop < 0)
994 {
995 stop += axis_size;
996 }
997
998 // Clamping
999 // Because the end index points one past the last element, we need slightly
1000 // different clamping ranges depending on the direction.
1001 if (strides[axis] > 0)
1002 {
1003 // Forward iteration
1004 stop = Clamp(stop, 0, axis_size);
1005 }
1006 else
1007 {
1008 // Backward iteration
1009 stop = Clamp(stop, -1, axis_size - 1);
1010 }
1011
1012 return stop;
1013}

References Clamp(), onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::stop_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().