ONE - On-device Neural Engine
onert::shape_inference Namespace Reference

Namespaces

namespace  bcq
 

Data Structures

struct  StridedSliceParams
 

Typedefs

using Shapes = std::vector< ir::Shape >
 

Functions

ir::Shape inferArgMinMaxShape (const ir::Shape &input_shape, int axis, int rank)
 
ir::Shape inferBatchMatMulShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape, const ir::operation::BatchMatMul::Param &param)
 
ir::Shape inferBCQFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf)
 
ir::Shape inferBCQGatherShape (const ir::Shape &indices_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf, int rank, const ir::operation::BCQGather::Param &param)
 
ir::Shape inferBroadcastToShape (const ir::Shape shp_shape, const int32_t *shp_buf)
 
ir::Shape inferConcatShape (const Shapes &in_shapes, const ir::operation::Concat::Param &param)
 
ir::Shape inferConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::Conv2D::Param &param)
 
ir::Shape inferDepthwiseConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::DepthwiseConv2D::Param &param)
 
ir::Shape inferEltwiseShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
 
ir::Shape inferExpandDimsShape (const ir::Shape &in_shape, int32_t axis)
 
template<typename T >
ir::Shape inferFillShape (const ir::Shape &fill_shape, const T *shape_buf)
 
ir::Shape inferFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, bool keep_num_dims)
 
ir::Shape inferGatherShape (const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis, int rank)
 
ir::Shape inferOnehotShape (const ir::Shape &input_shape, const int depth, int axis)
 
ir::Shape inferPackShape (const ir::Shape &input_shape, int axis, int rank, int num)
 
ir::Shape inferPadShape (const ir::Shape &in_shape, const int32_t *pad_buf, const size_t num_pads)
 
ir::Shape inferPoolShape (const ir::Shape &in_shape, const ir::operation::Pool2D::Param &param)
 
template<typename T >
ir::Shape inferRangeShape (T start_val, T limit_val, T delta_val)
 
ir::Shape inferReshapeShape (const ir::Shape &input_shape, const int32_t *shape_buf, const int32_t shape_num_elements)
 
ir::Shape inferReduceShape (const ir::Shape &input_shape, const std::vector< int > &axes, bool keep_dims)
 
ir::Shape inferResizeBilinearShape (const ir::Shape &in_shape, const int32_t output_height, const int32_t output_width)
 
ir::Shape inferSelectShape (const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape, const ir::Shape &input_false_shape)
 
template<typename T >
ir::Shape inferSliceShape (const ir::Shape &input_shape, const T *begins_buf, const T *sizes_buf)
 
ir::Shape inferSpaceToBatchNDShape (const ir::Shape &input_shape, const ir::Shape &block_shape_shape, const ir::Shape &padding_shape, const int32_t *block_shape_buf, const int32_t *padding_buf)
 
ir::Shape inferSplitShape (const ir::Shape input_shape, int axis_value, int num_splits)
 
ir::Shape inferSqueezeShape (const ir::Shape &in_shape, const ir::operation::Squeeze::Param &param)
 
template<typename T >
StridedSliceParams buildStridedSliceParams (const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
ir::Shape inferStridedSliceShape (const ir::Shape &input_shape, const StridedSliceParams &op_params, uint32_t rank)
 
ir::Shape inferTileShape (const ir::Shape &in_shape, const int32_t *multiplier_buf, const int32_t multiplier_size)
 
ir::Shape inferTransposeShape (const ir::Shape &in_shape, const int32_t *perm_buf, const int32_t perm_size)
 
ir::Shape inferUnpackShape (const ir::Shape &input_shape, int axis, int rank)
 
std::pair< int, int > calcConvLikeHeightAndWidth (const int in_h, const int in_w, const int ker_h, const int ker_w, const ir::Padding pad, const ir::Stride stride, const ir::Dilation dilation={1, 1})
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int32_t *shape_buf)
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int64_t *shape_buf)
 
template ir::Shape inferRangeShape (int start_val, int limit_val, int delta_val)
 
template ir::Shape inferRangeShape (float start_val, float limit_val, float delta_val)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int32_t *begins_buf, const int32_t *sizes_buf)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int64_t *begins_buf, const int64_t *sizes_buf)
 
template StridedSliceParams buildStridedSliceParams (const uint32_t *begin, const uint32_t *end, const uint32_t *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
int Clamp (const int v, const int lo, const int hi)
 
int StartForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis)
 
int StopForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis, int start_for_axis)
 

Typedef Documentation

◆ Shapes

using onert::shape_inference::Shapes = typedef std::vector<ir::Shape>

Definition at line 37 of file ShapeInference.h.

Function Documentation

◆ buildStridedSliceParams() [1/2]

template<typename T >
StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const T *  begin,
const T *  end,
const T *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

Definition at line 874 of file ShapeInference.cc.

877{
878 StridedSliceParams op_params;
879 op_params.start_indices_count = rank;
880 op_params.stop_indices_count = rank;
881 op_params.strides_count = rank;
882
883 for (int i = 0; i < op_params.strides_count; ++i)
884 {
885 op_params.start_indices[i] = begin[i];
886 op_params.stop_indices[i] = end[i];
887 op_params.strides[i] = strides[i];
888
889 assert(op_params.strides[i] != 0);
890 }
891
892 op_params.begin_mask = begin_mask;
893 op_params.ellipsis_mask = 0; // NYI
894 op_params.end_mask = end_mask;
895 op_params.new_axis_mask = 0; // NYI
896 op_params.shrink_axis_mask = shrink_axis_mask;
897
898 assert(sizeof(op_params.begin_mask) * 4 >= rank);
899
900 return op_params;
901}

References onert::shape_inference::StridedSliceParams::begin_mask, onert::shape_inference::StridedSliceParams::ellipsis_mask, onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::new_axis_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::start_indices, onert::shape_inference::StridedSliceParams::start_indices_count, onert::shape_inference::StridedSliceParams::stop_indices, onert::shape_inference::StridedSliceParams::stop_indices_count, onert::shape_inference::StridedSliceParams::strides, and onert::shape_inference::StridedSliceParams::strides_count.

Referenced by onert::exec::DynamicShapeInferer::visit().
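
The sketch below shows how this helper feeds inferStridedSliceShape() further down the page. It is a usage illustration, not code from the library: the include path "util/ShapeInference.h" and the wrapper function name are assumptions, and uint32_t buffers are used because that is the explicit instantiation listed here.

#include "util/ShapeInference.h" // assumed include path; the page only names ShapeInference.h

using namespace onert;

// Slice a [4, 6] tensor: rows 1..2 with stride 1, every second column.
ir::Shape stridedSliceExample()
{
  const ir::Shape input{4, 6};
  const uint32_t begin[] = {1, 0};
  const uint32_t end[] = {3, 6};
  const uint32_t strides[] = {1, 2};

  auto params = shape_inference::buildStridedSliceParams(
    begin, end, strides, /*begin_mask=*/0, /*end_mask=*/0, /*shrink_axis_mask=*/0, /*rank=*/2);

  // Per the definitions below: axis 0 gives ceil((3 - 1) / 1) = 2,
  // axis 1 gives ceil((6 - 0) / 2) = 3, so the inferred shape is {2, 3}.
  return shape_inference::inferStridedSliceShape(input, params, /*rank=*/2);
}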

◆ buildStridedSliceParams() [2/2]

template StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const uint32_t *  begin,
const uint32_t *  end,
const uint32_t *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

◆ calcConvLikeHeightAndWidth()

std::pair< int, int > onert::shape_inference::calcConvLikeHeightAndWidth ( const int  in_h,
const int  in_w,
const int  ker_h,
const int  ker_w,
const ir::Padding  pad,
const ir::Stride  stride,
const ir::Dilation  dilation = {1, 1} 
)

Definition at line 90 of file ShapeInference.cc.

93 {1, 1})
94{
95 int32_t out_h = 0, out_w = 0;
96 int32_t effective_filter_w_size = (ker_w - 1) * dilation.width_factor + 1;
97 int32_t effective_filter_h_size = (ker_h - 1) * dilation.height_factor + 1;
98 switch (pad.type)
99 {
100 case ir::PaddingType::SAME:
101 out_h = ceil_div(in_h, stride.vertical);
102 out_w = ceil_div(in_w, stride.horizontal);
103 break;
104 case ir::PaddingType::VALID:
105 out_h = ceil_div(in_h - effective_filter_h_size + 1, stride.vertical);
106 out_w = ceil_div(in_w - effective_filter_w_size + 1, stride.horizontal);
107 break;
108 case ir::PaddingType::EXPLICIT:
109 out_h =
110 (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
111 out_w =
112 (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + 1;
113 break;
114 default:
115 assert(false);
116 }
117
118 return {out_h, out_w};
119}

Referenced by inferConv2DShape(), inferDepthwiseConv2DShape(), and inferPoolShape().
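
As a worked example of the three padding branches (with effective filter size (ker - 1) * dilation + 1): for in_h = 32, ker_h = 3, stride.vertical = 2 and dilation = {1, 1}, SAME yields ceil(32 / 2) = 16, VALID yields ceil((32 - 3 + 1) / 2) = 15, and EXPLICIT padding with top = bottom = 1 yields (32 + 1 + 1 - 3) / 2 + 1 = 16.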

◆ Clamp()

int onert::shape_inference::Clamp ( const int  v,
const int  lo,
const int  hi 
)

Definition at line 909 of file ShapeInference.cc.

910{
911 assert(!(hi < lo));
912 if (hi < v)
913 return hi;
914 if (v < lo)
915 return lo;
916 return v;
917}

Referenced by StartForAxis(), and StopForAxis().

◆ inferArgMinMaxShape()

ir::Shape onert::shape_inference::inferArgMinMaxShape ( const ir::Shape input_shape,
int  axis,
int  rank 
)

Definition at line 126 of file ShapeInference.cc.

127{
128 if (axis < 0 || axis >= rank)
129 {
130 throw std::runtime_error("ArgMinMax shape inference: Wrong axis value " + std::to_string(axis));
131 }
132
133 ir::Shape out_shape;
134 for (int idx = 0; idx < rank; ++idx)
135 {
136 if (idx != axis)
137 {
138 int32_t input_dim = input_shape.dim(idx);
139 out_shape.append(input_dim);
140 }
141 }
142
143 return out_shape;
144}

Referenced by onert::exec::DynamicShapeInferer::visit().
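
A minimal call sketch (the include path and the example function name are assumptions, as in the earlier sketch): the axis dimension is removed and every other dimension is kept in order.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape argMinMaxExample()
{
  const ir::Shape input{2, 3, 4};
  // Dropping axis 1 leaves {2, 4}.
  return shape_inference::inferArgMinMaxShape(input, /*axis=*/1, /*rank=*/3);
}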

◆ inferBatchMatMulShape()

ir::Shape onert::shape_inference::inferBatchMatMulShape ( const ir::Shape lhs_shape,
const ir::Shape rhs_shape,
const ir::operation::BatchMatMul::Param param 
)

Definition at line 228 of file ShapeInference.cc.

230{
231 bool adj_x = param.adj_x;
232 bool adj_y = param.adj_y;
233 ir::Shape output_shape;
234
235 int output_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
236
237 // Extend lhs and rhs shape
238 ir::Shape extended_lhs_shape(lhs_shape);
239 ir::Shape extended_rhs_shape(rhs_shape);
240 extended_lhs_shape.extendRank(output_rank);
241 extended_rhs_shape.extendRank(output_rank);
242
243 for (int i = 0; i < output_rank - 2; i++)
244 {
245 const int lhs_dim = extended_lhs_shape.dim(i);
246 const int rhs_dim = extended_rhs_shape.dim(i);
247 int broadcast_dim = lhs_dim;
248 if (lhs_dim != rhs_dim)
249 {
250 if (lhs_dim == 1)
251 {
252 broadcast_dim = rhs_dim;
253 }
254 else if (rhs_dim != 1)
255 {
256 throw std::runtime_error{"BatchMatMul shape inference: invalid brodcasting input shape"};
257 }
258 }
259
260 output_shape.append(broadcast_dim);
261 }
262
263 // Fill in the matmul dimensions.
264 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
265 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
266
267 output_shape.append(extended_lhs_shape.dim(lhs_rows_index));
268 output_shape.append(extended_rhs_shape.dim(rhs_cols_index));
269
270 return output_shape;
271}

References onert::ir::operation::BatchMatMul::Param::adj_x and onert::ir::operation::BatchMatMul::Param::adj_y.

Referenced by onert::exec::DynamicShapeInferer::visit().
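
A hedged usage sketch: the include path is assumed as before, BatchMatMul::Param is assumed to be a default-constructible aggregate, and extendRank() is assumed to prepend leading 1s (which is what the broadcasting loop above requires).

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape batchMatMulExample()
{
  ir::operation::BatchMatMul::Param param; // assumption: plain struct with adj_x / adj_y
  param.adj_x = false;
  param.adj_y = false;

  const ir::Shape lhs{2, 1, 4, 3};
  const ir::Shape rhs{5, 3, 6}; // rank-extended to {1, 5, 3, 6}
  // Batch dims broadcast to {2, 5}; the matmul dims contribute 4 x 6 -> {2, 5, 4, 6}.
  return shape_inference::inferBatchMatMulShape(lhs, rhs, param);
}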

◆ inferBCQFullyConnectedShape()

ir::Shape onert::shape_inference::inferBCQFullyConnectedShape ( const ir::Shape in_shape,
const ir::Shape cluster_shape,
const int32_t *  cluster_buf 
)

Definition at line 416 of file ShapeInference.cc.

418{
419 assert(cluster_shape.rank() == 2);
420 assert(cluster_shape.dim(1) == 2);
421
422 const auto input_size = in_shape.dim(1);
423 const auto output_size = bcq::getOutputSize(cluster_shape, cluster_buf);
424
425 return {ir::Shape({output_size, input_size})};
426}

References onert::shape_inference::bcq::getOutputSize().

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBCQGatherShape()

ir::Shape onert::shape_inference::inferBCQGatherShape ( const ir::Shape indices_shape,
const ir::Shape cluster_shape,
const int32_t *  cluster_buf,
int  rank,
const ir::operation::BCQGather::Param param 
)

Definition at line 428 of file ShapeInference.cc.

431{
432 ir::Shape out_shape;
433 ir::Shape in_original_shape;
434
435 assert(cluster_shape.rank() == 2);
436 assert(cluster_shape.dim(1) == 2);
437
438 auto hidden_size = param.input_hidden_size;
439 auto axis = param.axis;
440
441 in_original_shape.append(bcq::getOutputSize(cluster_shape, cluster_buf));
442 in_original_shape.append(hidden_size);
443
444 const int indices_rank = indices_shape.rank();
445 for (int idx = 0; idx < rank; ++idx)
446 {
447 if (idx == (int)axis)
448 {
449 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
450 {
451 out_shape.append(indices_shape.dim(indices_idx));
452 }
453 }
454 else
455 {
456 out_shape.append(in_original_shape.dim(idx));
457 }
458 }
459
460 return out_shape;
461}

References onert::ir::operation::BCQGather::Param::axis, onert::shape_inference::bcq::getOutputSize(), and onert::ir::operation::BCQGather::Param::input_hidden_size.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBroadcastToShape()

ir::Shape onert::shape_inference::inferBroadcastToShape ( const ir::Shape  shp_shape,
const int32_t *  shp_buf 
)

Definition at line 277 of file ShapeInference.cc.

278{
279
280 const int num_elements = shp_shape.num_elements();
281
282 assert(num_elements != 0);
283 assert(shp_buf);
284
285 ir::Shape new_shape(num_elements);
286
287 for (int i = 0; i < num_elements; ++i)
288 {
289 assert(shp_buf[i] != 0); // It shouldn't be 0.
290 new_shape.dim(i) = shp_buf[i];
291 }
292
293 return new_shape;
294}

Referenced by onert::exec::DynamicShapeInferer::visit().
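
BroadcastTo receives its target shape as a 1-D shape tensor; the sketch below (include path assumed, example name hypothetical) passes a buffer of three extents.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape broadcastToExample()
{
  const ir::Shape shp_shape{3};           // the shape tensor itself: rank 1, three entries
  const int32_t shp_buf[] = {2, 3, 4};
  return shape_inference::inferBroadcastToShape(shp_shape, shp_buf); // {2, 3, 4}
}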

◆ inferConcatShape()

ir::Shape onert::shape_inference::inferConcatShape ( const Shapes in_shapes,
const ir::operation::Concat::Param param 
)

Definition at line 296 of file ShapeInference.cc.

297{
298 const int32_t concat_axis = param.axis >= 0 ? param.axis : in_shapes[0].rank() + param.axis;
299 const auto &first_in_shape = in_shapes[0];
300
301 // Check that all shapes are equal except for concat axis dimension
302 for (const auto &in_shape : in_shapes)
303 {
304 if (in_shape.rank() != first_in_shape.rank())
305 throw std::runtime_error("Rank in all input tensors should be same");
306
307 for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
308 if (!(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx)))
309 throw std::runtime_error("All tensor should have same dimension "
310 "except dimension on passed axis");
311 }
312
313 // Calculate output shape
314 ir::Shape out_shape(first_in_shape);
315 out_shape.dim(concat_axis) = 0;
316 for (const auto &in_shape : in_shapes)
317 out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
318 return out_shape;
319}

References onert::ir::operation::Concat::Param::axis.

Referenced by onert::exec::DynamicShapeInferer::visit().
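
A short sketch of concatenating two NHWC tensors along the channel axis (include path assumed; Concat::Param is assumed to be an aggregate whose only relevant field here is axis).

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape concatExample()
{
  ir::operation::Concat::Param param; // assumption: default-constructible
  param.axis = 3;

  const shape_inference::Shapes inputs{ir::Shape{1, 8, 8, 16}, ir::Shape{1, 8, 8, 32}};
  // All dims match except the concat axis, which sums: {1, 8, 8, 48}.
  return shape_inference::inferConcatShape(inputs, param);
}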

◆ inferConv2DShape()

ir::Shape onert::shape_inference::inferConv2DShape ( const ir::Shape in_shape,
const ir::Shape ker_shape,
const ir::operation::Conv2D::Param param 
)

Definition at line 321 of file ShapeInference.cc.

323{
324 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
325 throw std::runtime_error{"Conv2D: stride values must be positive"};
326
327 auto ifm_shape = in_shape.asFeature();
328
329 // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
330 auto kf_shape = ker_shape.asFeature();
331 assert(ifm_shape.C == kf_shape.C);
332
333 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
334 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
335
336 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.N};
337}

References calcConvLikeHeightAndWidth(), onert::ir::operation::Conv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::Conv2D::Param::padding, onert::ir::operation::Conv2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferDepthwiseConv2DShape()

ir::Shape onert::shape_inference::inferDepthwiseConv2DShape ( const ir::Shape in_shape,
const ir::Shape ker_shape,
const ir::operation::DepthwiseConv2D::Param param 
)

Definition at line 339 of file ShapeInference.cc.

341{
342 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
343 throw std::runtime_error{"DepthwiseConv2D: stride values must be positive"};
344
345 auto ifm_shape = in_shape.asFeature();
346
347 // Kernel format is [1, kernel_height, kernel_width, depth_out]
348 auto kf_shape = ker_shape.asFeature();
349 assert(kf_shape.C == static_cast<int32_t>(ifm_shape.C * param.multiplier));
350 assert(kf_shape.N == 1);
351
352 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
353 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
354
355 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.C};
356}

References calcConvLikeHeightAndWidth(), onert::ir::operation::DepthwiseConv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::DepthwiseConv2D::Param::multiplier, onert::ir::operation::DepthwiseConv2D::Param::padding, onert::ir::operation::DepthwiseConv2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferEltwiseShape()

ir::Shape onert::shape_inference::inferEltwiseShape ( const ir::Shape lhs_shape,
const ir::Shape rhs_shape 
)

Definition at line 121 of file ShapeInference.cc.

122{
123 return broadcastShapes(lhs_shape, rhs_shape);
124}

◆ inferExpandDimsShape()

ir::Shape onert::shape_inference::inferExpandDimsShape ( const ir::Shape in_shape,
int32_t  axis 
)

Definition at line 358 of file ShapeInference.cc.

359{
360 ir::Shape out_shape(in_shape.rank() + 1);
361
362 axis = ((axis >= 0) ? axis : /* when axis < 0 */ (out_shape.rank() + axis));
363 if (!(0 <= axis && axis <= in_shape.rank()))
364 throw std::runtime_error("axis of dim is out of range");
365
366 for (int x = 0, out_x = 0; out_x < out_shape.rank(); ++out_x)
367 {
368 if (out_x == axis)
369 out_shape.dim(out_x) = 1;
370 else
371 out_shape.dim(out_x) = in_shape.dim(x++);
372 }
373
374 return out_shape;
375}

Referenced by onert::exec::DynamicShapeInferer::visit().
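
A small sketch (same include-path assumption) showing the negative-axis normalization described above.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape expandDimsExample()
{
  const ir::Shape input{2, 3};
  // axis == -1 is normalized against the output rank (3 - 1 = 2), giving {2, 3, 1}.
  return shape_inference::inferExpandDimsShape(input, /*axis=*/-1);
}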

◆ inferFillShape() [1/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const int32_t *  shape_buf 
)

◆ inferFillShape() [2/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const int64_t *  shape_buf 
)

◆ inferFillShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape fill_shape,
const T *  shape_buf 
)

Definition at line 377 of file ShapeInference.cc.

378{
379 ir::Shape out_shape(fill_shape.dim(0));
380
381 for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
382 {
383 out_shape.dim(out_x) = static_cast<int32_t>(shape_buf[out_x]);
384 }
385
386 return out_shape;
387}

◆ inferFullyConnectedShape()

ir::Shape onert::shape_inference::inferFullyConnectedShape ( const ir::Shape in_shape,
const ir::Shape ker_shape,
bool  keep_num_dims 
)

Definition at line 393 of file ShapeInference.cc.

395{
396 assert(in_shape.rank() >= 2);
397 assert(ker_shape.rank() == 2);
398
399 const auto input_size_with_batch = in_shape.num_elements();
400 const auto num_units = ker_shape.dim(0);
401 const auto input_size = ker_shape.dim(1);
402 const auto batch_size = input_size_with_batch / input_size;
403 assert(input_size_with_batch % input_size == 0);
404
405 if (keep_num_dims)
406 {
407 assert(in_shape.dim(in_shape.rank() - 1) == input_size);
408 auto output_shape = ir::Shape(in_shape);
409 output_shape.dim(output_shape.rank() - 1) = num_units;
410 return output_shape;
411 }
412
413 return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
414}

Referenced by onert::exec::DynamicShapeInferer::visit().
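
A usage sketch (include path assumed) contrasting the two keep_num_dims behaviours.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape fullyConnectedExample()
{
  const ir::Shape input{2, 4, 8};    // 64 elements
  const ir::Shape weights{16, 8};    // num_units = 16, input_size = 8
  // keep_num_dims == false flattens the batch: 64 / 8 = 8 rows -> {8, 16};
  // keep_num_dims == true would instead keep the leading dims -> {2, 4, 16}.
  return shape_inference::inferFullyConnectedShape(input, weights, /*keep_num_dims=*/false);
}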

◆ inferGatherShape()

ir::Shape onert::shape_inference::inferGatherShape ( const ir::Shape input_shape,
const ir::Shape indices_shape,
int  axis,
int  rank 
)

Definition at line 463 of file ShapeInference.cc.

465{
466 ir::Shape out_shape;
467
468 const int indices_rank = indices_shape.rank();
469
470 for (int idx = 0; idx < rank; ++idx)
471 {
472 if (idx == axis)
473 {
474 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
475 {
476 out_shape.append(indices_shape.dim(indices_idx));
477 }
478 }
479 else
480 {
481 out_shape.append(input_shape.dim(idx));
482 }
483 }
484
485 return out_shape;
486}

Referenced by onert::exec::DynamicShapeInferer::visit().
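
A minimal sketch (include path assumed): the axis dimension of the input is replaced by the whole indices shape.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape gatherExample()
{
  const ir::Shape input{4, 5};
  const ir::Shape indices{2, 3};
  // Gathering along axis 0 yields {2, 3, 5}.
  return shape_inference::inferGatherShape(input, indices, /*axis=*/0, /*rank=*/2);
}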

◆ inferOnehotShape()

ir::Shape onert::shape_inference::inferOnehotShape ( const ir::Shape input_shape,
const int  depth,
int  axis 
)

Definition at line 488 of file ShapeInference.cc.

489{
490 assert(depth >= 0);
491 const auto rank = input_shape.rank() + 1;
492 ir::Shape newShape(rank);
493
494 axis = (axis == -1) ? (rank - 1) : axis;
495
496 for (int i = 0; i < rank; ++i)
497 {
498 if (i < axis)
499 {
500 newShape.dim(i) = input_shape.dim(i);
501 }
502 else if (i == axis)
503 {
504 newShape.dim(i) = depth;
505 }
506 else
507 {
508 newShape.dim(i) = input_shape.dim(i - 1);
509 }
510 }
511
512 return newShape;
513}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPackShape()

ir::Shape onert::shape_inference::inferPackShape ( const ir::Shape input_shape,
int  axis,
int  rank,
int  num 
)

Definition at line 515 of file ShapeInference.cc.

516{
517 ir::Shape out_shape;
518 int in_idx = 0;
519
520 for (int out_idx = 0; out_idx < rank; ++out_idx)
521 {
522 if (out_idx == axis)
523 {
524 out_shape.append(num);
525 }
526 else
527 {
528 out_shape.append(input_shape.dim(in_idx++));
529 }
530 }
531
532 return out_shape;
533}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPadShape()

ir::Shape onert::shape_inference::inferPadShape ( const ir::Shape in_shape,
const int32_t *  pad_buf,
const size_t  num_pads 
)

Definition at line 535 of file ShapeInference.cc.

536{
537 assert(num_pads % 2 == 0);
538 const int32_t rank = num_pads / 2;
539
540 ir::Shape ret(rank);
541 for (int32_t i = 0; i < rank; ++i)
542 {
543 const auto before_padding = pad_buf[i * 2];
544 const auto after_padding = pad_buf[i * 2 + 1];
545
546 ret.dim(i) = in_shape.dim(i) + before_padding + after_padding;
547 }
548
549 return ret;
550}

Referenced by onert::exec::DynamicShapeInferer::visit().
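
A short sketch (include path assumed): pad_buf holds before/after pairs per dimension, and each output dim is the input dim plus its two paddings.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape padExample()
{
  const ir::Shape input{2, 3};
  const int32_t pad_buf[] = {1, 1,   // dim 0: before, after
                             0, 2};  // dim 1: before, after
  return shape_inference::inferPadShape(input, pad_buf, /*num_pads=*/4); // {4, 5}
}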

◆ inferPoolShape()

ir::Shape onert::shape_inference::inferPoolShape ( const ir::Shape in_shape,
const ir::operation::Pool2D::Param param 
)

Definition at line 552 of file ShapeInference.cc.

553{
554 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
555 throw std::runtime_error{"Pool2D: stride values must be positive"};
556
557 auto ifm_shape = in_shape.asFeature();
558 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh,
559 param.kw, param.padding, param.stride);
560 // Pooling don't change number of channels and batch size
561 return ir::Shape{ifm_shape.N, out_h, out_w, ifm_shape.C};
562}

References calcConvLikeHeightAndWidth(), onert::ir::Stride::horizontal, onert::ir::operation::Pool2D::Param::kh, onert::ir::operation::Pool2D::Param::kw, onert::ir::operation::Pool2D::Param::padding, onert::ir::operation::Pool2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferRangeShape() [1/3]

template ir::Shape onert::shape_inference::inferRangeShape ( float  start_val,
float  limit_val,
float  delta_val 
)

◆ inferRangeShape() [2/3]

template ir::Shape onert::shape_inference::inferRangeShape ( int  start_val,
int  limit_val,
int  delta_val 
)

◆ inferRangeShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferRangeShape ( T  start_val,
T  limit_val,
T  delta_val 
)

Definition at line 589 of file ShapeInference.cc.

590{
591 ir::Shape out_shape(static_cast<int>(1));
592
593 out_shape.dim(0) =
594 (std::is_integral<T>::value
595 ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
596 : std::ceil(std::abs((start_val - limit_val) / delta_val)));
597 return out_shape;
598}
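
A worked call (include path assumed), using the int instantiation listed on this page.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape rangeExample()
{
  // |3 - 18| = 15 stepped by 4 -> (15 + 4 - 1) / 4 = 4 elements (3, 7, 11, 15),
  // so the result is the rank-1 shape {4}.
  return shape_inference::inferRangeShape<int>(3, 18, 4);
}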

◆ inferReduceShape()

ir::Shape onert::shape_inference::inferReduceShape ( const ir::Shape input_shape,
const std::vector< int > &  axes,
bool  keep_dims 
)

Definition at line 146 of file ShapeInference.cc.

148{
149 int num_axis = axes.size();
150 int input_num_dims = input_shape.rank();
151 if (input_num_dims == 0)
152 {
153 ir::Shape out_shape(0);
154 return out_shape;
155 }
156 if (keep_dims)
157 {
158 ir::Shape out_shape;
159 for (int idx = 0; idx < input_num_dims; ++idx)
160 {
161 bool is_axis = false;
162 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
163 {
164 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
165 {
166 is_axis = true;
167 break;
168 }
169 }
170 if (is_axis)
171 {
172 out_shape.append(1);
173 }
174 else
175 {
176 out_shape.append(input_shape.dim(idx));
177 }
178 }
179 return out_shape;
180 }
181 else
182 {
183 // Calculates size of reducing axis.
184 for (int i = 0; i < num_axis; ++i)
185 {
186 int current = axes[i];
187 if (!(-input_num_dims <= current && current < input_num_dims))
188 throw std::runtime_error{"Invalid dim value " + std::to_string(current)};
189 if (current < 0)
190 {
191 current += input_num_dims;
192 }
193 for (int j = 0; j < i; ++j)
194 {
195 int previous = axes[j];
196 if (previous < 0)
197 {
198 previous += input_num_dims;
199 }
200 if (current == previous)
201 {
202 break;
203 }
204 }
205 }
206 // Determines output dimensions.
207 ir::Shape out_shape;
208 for (int idx = 0; idx < input_num_dims; ++idx)
209 {
210 bool is_axis = false;
211 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
212 {
213 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
214 {
215 is_axis = true;
216 break;
217 }
218 }
219 if (!is_axis)
220 {
221 out_shape.append(input_shape.dim(idx));
222 }
223 }
224 return out_shape;
225 }
226}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferReshapeShape()

ir::Shape onert::shape_inference::inferReshapeShape ( const ir::Shape input_shape,
const int32_t *  shape_buf,
const int32_t  shape_num_elements 
)

Definition at line 604 of file ShapeInference.cc.

606{
607 ir::Shape ret(shape_num_elements);
608 int32_t flatten_dim = ir::Shape::kUnspecifiedDim;
609 auto total_num_elements = input_shape.num_elements();
610 for (int32_t i = 0; i < shape_num_elements; ++i)
611 {
612 if (shape_buf[i] < 0)
613 {
614 if (flatten_dim != ir::Shape::kUnspecifiedDim)
615 throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
616 flatten_dim = i;
617 ret.dim(i) = 1;
618 }
619 else
620 {
621 ret.dim(i) = shape_buf[i];
622 }
623 }
624 if (flatten_dim != ir::Shape::kUnspecifiedDim)
625 ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
626
627 // Check reshapable
628 if (total_num_elements != static_cast<size_t>(ret.num_elements()))
629 {
630 // Multi batch case
631 // TODO Handle multi batch case more precisely on runtime level
632 if ((ret.dim(0) == 1) &&
633 (total_num_elements == static_cast<size_t>(ret.num_elements() * input_shape.dim(0))))
634 ret.dim(0) = input_shape.dim(0);
635 else
636 throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
637 }
638
639 return ret;
640}

Referenced by onert::exec::DynamicShapeInferer::visit().
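
A sketch of the "-1 means infer this dimension" path (include path assumed).

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape reshapeExample()
{
  const ir::Shape input{2, 3, 4};               // 24 elements
  const int32_t shape_buf[] = {4, -1};
  // The -1 entry is filled with 24 / 4 = 6, giving {4, 6}.
  return shape_inference::inferReshapeShape(input, shape_buf, /*shape_num_elements=*/2);
}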

◆ inferResizeBilinearShape()

ir::Shape onert::shape_inference::inferResizeBilinearShape ( const ir::Shape in_shape,
const int32_t  output_height,
const int32_t  output_width 
)

Definition at line 564 of file ShapeInference.cc.

566{
567 assert(in_shape.rank() == 4);
568 if (output_height < 0)
569 {
570 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_height = " +
571 std::to_string(output_height)};
572 }
573 if (output_width < 0)
574 {
575 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_width = " +
576 std::to_string(output_width)};
577 }
578
579 ir::Shape ret(in_shape.rank());
580
581 ret.dim(0) = in_shape.dim(0);
582 ret.dim(1) = output_height;
583 ret.dim(2) = output_width;
584 ret.dim(3) = in_shape.dim(3);
585
586 return ret;
587}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSelectShape()

ir::Shape onert::shape_inference::inferSelectShape ( const ir::Shape input_cond_shape,
const ir::Shape input_true_shape,
const ir::Shape input_false_shape 
)

Definition at line 642 of file ShapeInference.cc.

644{
645 auto haveSameShapes = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
646 const ir::Shape &input_false_shape) {
647 if ((input_cond_shape.rank() != input_true_shape.rank()) ||
648 input_cond_shape.rank() != input_false_shape.rank())
649 {
650 return false;
651 }
652
653 int rank = input_cond_shape.rank();
654 for (int i = 0; i < rank; ++i)
655 {
656 if (input_cond_shape.dim(i) != input_true_shape.dim(i) ||
657 input_cond_shape.dim(i) != input_false_shape.dim(i))
658 {
659 return false;
660 }
661 }
662
663 return true;
664 };
665
666 auto calculateShape = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
667 const ir::Shape &input_false_shape, ir::Shape &new_shape) {
668 ir::Shape cond_shape = input_cond_shape;
669 ir::Shape true_shape = input_true_shape;
670 ir::Shape false_shape = input_false_shape;
671 int most_rank =
672 (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
673 ? cond_shape.rank()
674 : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
675
676 ir::Shape calculate_shape(most_rank);
677
678 cond_shape.extendRank(most_rank);
679 true_shape.extendRank(most_rank);
680 false_shape.extendRank(most_rank);
681
682 for (int i = 0; i < most_rank; ++i)
683 {
684 calculate_shape.dim(i) =
685 (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
686 ? cond_shape.dim(i)
687 : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
688
689 if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
690 (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
691 (false_shape.dim(i) != calculate_shape.dim(i) && false_shape.dim(i) != 1))
692 {
693 return false;
694 }
695 }
696
697 new_shape = calculate_shape;
698
699 return true;
700 };
701
702 bool havesame = haveSameShapes(input_cond_shape, input_true_shape, input_false_shape);
703 if (havesame)
704 {
705 return input_cond_shape;
706 }
707
708 ir::Shape new_shape;
709 bool possible = calculateShape(input_cond_shape, input_true_shape, input_false_shape, new_shape);
710
711 if (!possible)
712 {
713 throw std::runtime_error("Broadcasting is not possible.");
714 }
715
716 return new_shape;
717}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSliceShape() [1/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const int32_t *  begins_buf,
const int32_t *  sizes_buf 
)

◆ inferSliceShape() [2/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const int64_t *  begins_buf,
const int64_t *  sizes_buf 
)

◆ inferSliceShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape input_shape,
const T *  begins_buf,
const T *  sizes_buf 
)

Definition at line 720 of file ShapeInference.cc.

721{
722 const uint32_t rank = input_shape.rank();
723 ir::Shape out_shape(rank);
724
725 for (uint32_t idx = 0; idx < rank; ++idx)
726 {
727 const auto input_dim = input_shape.dim(idx);
728
729 // begin is zero-based
730 auto begin = begins_buf[idx];
731 if (begin < 0)
732 throw std::runtime_error("shape inference Slice: Invalid begin.");
733
734 // size is one-based
735 auto size = sizes_buf[idx];
736 if (size < -1)
737 throw std::runtime_error("shape inference Slice: Invalid size.");
738
739 if (size == -1)
740 {
741 size = input_dim - begin;
742 }
743 else
744 {
745 if (input_dim < static_cast<int32_t>(begin + size))
746 throw std::runtime_error("shape inference Slice: Invalid begin and size.");
747 }
748 out_shape.dim(idx) = static_cast<int32_t>(size);
749 }
750
751 return out_shape;
752}

Referenced by onert::exec::DynamicShapeInferer::visit().
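
A minimal sketch using the int32_t instantiation listed above (include path assumed); a size of -1 extends the slice to the end of that axis.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape sliceExample()
{
  const ir::Shape input{4, 6};
  const int32_t begins[] = {1, 2};
  const int32_t sizes[] = {2, -1};
  // Axis 0 keeps the requested 2; axis 1 becomes 6 - 2 = 4 -> {2, 4}.
  return shape_inference::inferSliceShape(input, begins, sizes);
}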

◆ inferSpaceToBatchNDShape()

ir::Shape onert::shape_inference::inferSpaceToBatchNDShape ( const ir::Shape input_shape,
const ir::Shape block_shape_shape,
const ir::Shape padding_shape,
const int32_t *  block_shape_buf,
const int32_t *  padding_buf 
)

Definition at line 759 of file ShapeInference.cc.

763{
764 const uint32_t rank = input_shape.rank();
765 ir::Shape out_shape(rank);
766
767 // Currently, only 4D NHWC input/output op_context are supported.
768 // The 4D array need to have exactly 2 spatial dimensions.
769 // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
770 [[maybe_unused]] const int32_t kInputDimensionNum = 4;
771 [[maybe_unused]] const int32_t kBlockSizeDimensionNum = 1;
772 const int32_t kSpatialDimensionNum = 2;
773
774 assert(block_shape_shape.rank() == kBlockSizeDimensionNum);
775 assert(block_shape_shape.dim(0) == kSpatialDimensionNum);
776 assert(padding_shape.dim(0) == kSpatialDimensionNum);
777 assert(padding_shape.dim(1) == 2); // fixed, meaning left/right padding for each element
778 assert(padding_shape.rank() == 2); // fixed, meaning dimension(dim 0) and padding length(dim 1)
779
780 // Ensures the input height and width (with padding) is a multiple of block
781 // shape height and width.
782 for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
783 {
784 int final_dim_size =
785 (input_shape.dim(dim + 1) + padding_buf[dim * 2] + padding_buf[dim * 2 + 1]);
786
787 assert(final_dim_size % block_shape_buf[dim] == 0);
788
789 out_shape.dim(dim + 1) = final_dim_size / block_shape_buf[dim];
790 }
791
792 const int output_batch_size = input_shape.dim(0) * block_shape_buf[0] * block_shape_buf[1];
793 const int output_channel_size = input_shape.dim(3);
794
795 out_shape.dim(0) = output_batch_size;
796 out_shape.dim(3) = output_channel_size;
797
798 return out_shape;
799}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSplitShape()

ir::Shape onert::shape_inference::inferSplitShape ( const ir::Shape  input_shape,
int  axis_value,
int  num_splits 
)

Definition at line 801 of file ShapeInference.cc.

802{
803 ir::Shape newShape(input_shape);
804
805 assert(axis_value >= 0);
806 assert(axis_value < input_shape.rank());
807
808 const int input_size = input_shape.dim(axis_value);
809 assert(input_size % num_splits == 0);
810 const int slice_size = input_size / num_splits;
811
812 newShape.dim(axis_value) = slice_size;
813
814 return newShape;
815}

Referenced by onert::exec::DynamicShapeInferer::visit().
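
A short sketch (include path assumed): the split axis must divide evenly by num_splits.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape splitExample()
{
  const ir::Shape input{4, 6};
  // Each of the 3 outputs gets 6 / 3 = 2 on axis 1 -> {4, 2}.
  return shape_inference::inferSplitShape(input, /*axis_value=*/1, /*num_splits=*/3);
}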

◆ inferSqueezeShape()

ir::Shape onert::shape_inference::inferSqueezeShape ( const ir::Shape in_shape,
const ir::operation::Squeeze::Param param 
)

Definition at line 817 of file ShapeInference.cc.

818{
819 const int ndims = param.ndim;
820 const int *squeeze_dims = param.dims;
821 bool should_squeeze[8] = {false};
822 int num_squeezed_dims = 0;
823 int shape_rank = in_shape.rank();
824 if (ndims == 0)
825 {
826 for (int idx = 0; idx < shape_rank; ++idx)
827 {
828 if (in_shape.dim(idx) == 1)
829 {
830 should_squeeze[idx] = true;
831 ++num_squeezed_dims;
832 }
833 }
834 }
835 else
836 {
837 for (int idx = 0; idx < ndims; ++idx)
838 {
839 int current = squeeze_dims[idx];
840 if (current < 0)
841 {
842 current += shape_rank;
843 }
844
845 if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
846 {
847 throw std::runtime_error(
848 "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
849 }
850
851 if (!should_squeeze[current])
852 {
853 ++num_squeezed_dims;
854 }
855 should_squeeze[current] = true;
856 }
857 }
858
859 // Set output shape.
860 ir::Shape out_shape(shape_rank - num_squeezed_dims);
861 for (int in_idx = 0, out_idx = 0; in_idx < shape_rank; ++in_idx)
862 {
863 if (!should_squeeze[in_idx])
864 {
865 out_shape.dim(out_idx++) = in_shape.dim(in_idx);
866 }
867 }
868
869 return out_shape;
870}

References onert::ir::operation::Squeeze::Param::dims, and onert::ir::operation::Squeeze::Param::ndim.

Referenced by onert::exec::DynamicShapeInferer::visit().
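
A sketch of the ndim == 0 path, which drops every size-1 dimension (include path assumed; Squeeze::Param is assumed to be a zero-initializable aggregate).

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape squeezeExample()
{
  ir::operation::Squeeze::Param param{}; // assumption: aggregate; ndim = 0 squeezes all unit dims
  param.ndim = 0;

  const ir::Shape input{1, 3, 1, 4};
  return shape_inference::inferSqueezeShape(input, param); // {3, 4}
}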

◆ inferStridedSliceShape()

ir::Shape onert::shape_inference::inferStridedSliceShape ( const ir::Shape input_shape,
const StridedSliceParams op_params,
uint32_t  rank 
)

Definition at line 1024 of file ShapeInference.cc.

1026{
1027 ir::Shape out_shape;
1028
1029 for (uint32_t idx = 0; idx < rank; ++idx)
1030 {
1031 int32_t stride = op_params.strides[idx];
1032 int32_t begin = StartForAxis(op_params, input_shape, idx);
1033 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
1034
1035 // When shrinking an axis, the end position does not matter (and can be
1036 // incorrect when negative indexing is used, see Issue #19260). Always use
1037 // begin + 1 to generate a length 1 slice, since begin has
1038 // already been adjusted for negative indices by StartForAxis.
1039 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
1040 if (shrink_axis)
1041 {
1042 end = begin + 1;
1043 }
1044
1045 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
1046 dim_shape = dim_shape < 0 ? 0 : dim_shape;
1047 if (!shrink_axis)
1048 {
1049 out_shape.append(dim_shape);
1050 }
1051 }
1052
1053 return out_shape;
1054}

References onert::shape_inference::StridedSliceParams::shrink_axis_mask, StartForAxis(), StopForAxis(), and onert::shape_inference::StridedSliceParams::strides.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferTileShape()

ir::Shape onert::shape_inference::inferTileShape ( const ir::Shape in_shape,
const int32_t *  multiplier_buf,
const int32_t  multiplier_size 
)

Definition at line 1056 of file ShapeInference.cc.

1058{
1059 if (multiplier_size != in_shape.rank())
1060 {
1061 throw std::runtime_error(
1062 "inferTileShape failed, input rank: " + std::to_string(in_shape.rank()) +
1063 ", bad multipliers size: " + std::to_string(multiplier_size) + "");
1064 }
1065 ir::Shape new_Shape(in_shape.rank());
1066
1067 for (int i = 0; i < in_shape.rank(); ++i)
1068 {
1069 assert(multiplier_buf[i]); // multiplier_buf[i] shuld not be 0.
1070 new_Shape.dim(i) = in_shape.dim(i) * multiplier_buf[i];
1071 }
1072 return new_Shape;
1073}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferTransposeShape()

ir::Shape onert::shape_inference::inferTransposeShape ( const ir::Shape in_shape,
const int32_t *  perm_buf,
const int32_t  perm_size 
)

Definition at line 1075 of file ShapeInference.cc.

1077{
1078 const auto rank = in_shape.rank();
1079 if (perm_size > rank)
1080 {
1081 throw std::runtime_error("inferTransposeShape failed, bad permutation size: " +
1082 std::to_string(perm_size));
1083 }
1084
1085 const int32_t *perm_data = perm_buf;
1086 std::vector<int32_t> regular_perm_vec;
1087 if (perm_size == 0)
1088 {
1089 // perm_data will be set to (n-1...0)
1090 regular_perm_vec.resize(rank);
1091 std::iota(regular_perm_vec.begin(), regular_perm_vec.end(), 0);
1092 std::reverse(regular_perm_vec.begin(), regular_perm_vec.end());
1093 perm_data = regular_perm_vec.data();
1094 }
1095 else
1096 {
1097 assert(rank == perm_size);
1098 }
1099
1100 ir::Shape out_shape(rank);
1101 std::vector<bool> visit_perms(rank, false);
1102 for (int idx = 0; idx < rank; idx++)
1103 {
1104 const auto perm_val = perm_data[idx];
1105 // Check invalid permutation value
1106 if (perm_val < 0 || perm_val >= rank)
1107 {
1108 throw std::runtime_error("inferTransposeShape failed, bad permutation value: " +
1109 std::to_string(perm_val));
1110 }
1111
1112 // Check duplicated permutation value
1113 if (visit_perms.at(perm_val))
1114 {
1115 throw std::runtime_error("inferTransposeShape failed, duplicated permutation value: " +
1116 std::to_string(perm_val));
1117 }
1118 visit_perms.at(perm_val) = true;
1119
1120 out_shape.dim(idx) = in_shape.dim(perm_val);
1121 }
1122 return out_shape;
1123}

Referenced by onert::exec::DynamicShapeInferer::visit().
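
A usage sketch for an NHWC to NCHW permutation (include path assumed); passing perm_size == 0 would instead reverse all dimensions.

#include "util/ShapeInference.h" // assumed include path
using namespace onert;

ir::Shape transposeExample()
{
  const ir::Shape input{1, 8, 8, 16};
  const int32_t perm[] = {0, 3, 1, 2};
  return shape_inference::inferTransposeShape(input, perm, /*perm_size=*/4); // {1, 16, 8, 8}
}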

◆ inferUnpackShape()

ir::Shape onert::shape_inference::inferUnpackShape ( const ir::Shape input_shape,
int  axis,
int  rank 
)

Definition at line 1125 of file ShapeInference.cc.

1126{
1127 ir::Shape out_shape;
1128
1129 for (int out_idx = 0; out_idx < rank; out_idx++)
1130 {
1131 if (out_idx != axis)
1132 {
1133 out_shape.append(input_shape.dim(out_idx));
1134 }
1135 }
1136
1137 return out_shape;
1138}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ StartForAxis()

int onert::shape_inference::StartForAxis ( const StridedSliceParams params,
const ir::Shape input_shape,
int  axis 
)

Definition at line 919 of file ShapeInference.cc.

920{
921 const auto begin_mask = params.begin_mask;
922 const auto *start_indices = params.start_indices;
923 const auto *strides = params.strides;
924 // Begin with the specified index.
925 int start = start_indices[axis];
926
927 // begin_mask override
928 if (begin_mask & 1 << axis)
929 {
930 if (strides[axis] > 0)
931 {
932 // Forward iteration - use the first element. These values will get
933 // clamped below (Note: We could have set them to 0 and axis_size-1, but
934 // use lowest() and max() to maintain symmetry with StopForAxis())
935 start = std::numeric_limits<int>::lowest();
936 }
937 else
938 {
939 // Backward iteration - use the last element.
940 start = std::numeric_limits<int>::max();
941 }
942 }
943
944 // Handle negative indices
945 int axis_size = input_shape.dim(axis);
946 if (start < 0)
947 {
948 start += axis_size;
949 }
950
951 // Clamping
952 start = Clamp(start, 0, axis_size - 1);
953
954 return start;
955}

References onert::shape_inference::StridedSliceParams::begin_mask, Clamp(), onert::shape_inference::StridedSliceParams::start_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().

◆ StopForAxis()

int onert::shape_inference::StopForAxis ( const StridedSliceParams params,
const ir::Shape input_shape,
int  axis,
int  start_for_axis 
)

Definition at line 962 of file ShapeInference.cc.

964{
965 const auto end_mask = params.end_mask;
966 const auto shrink_axis_mask = params.shrink_axis_mask;
967 const auto *stop_indices = params.stop_indices;
968 const auto *strides = params.strides;
969
970 // Begin with the specified index
971 const bool shrink_axis = shrink_axis_mask & (1 << axis);
972 int stop = stop_indices[axis];
973
974 // When shrinking an axis, the end position does not matter (and can be
975 // incorrect when negative indexing is used, see Issue #19260). Always use
976 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
977 // already been adjusted for negative indices.
978 if (shrink_axis)
979 {
980 stop = start_for_axis + 1;
981 }
982
983 // end_mask override
984 if (end_mask & (1 << axis))
985 {
986 if (strides[axis] > 0)
987 {
988 // Forward iteration - use the last element. These values will get
989 // clamped below
990 stop = std::numeric_limits<int>::max();
991 }
992 else
993 {
994 // Backward iteration - use the first element.
995 stop = std::numeric_limits<int>::lowest();
996 }
997 }
998
999 // Handle negative indices
1000
1001 const int axis_size = input_shape.dim(axis);
1002 if (stop < 0)
1003 {
1004 stop += axis_size;
1005 }
1006
1007 // Clamping
1008 // Because the end index points one past the last element, we need slightly
1009 // different clamping ranges depending on the direction.
1010 if (strides[axis] > 0)
1011 {
1012 // Forward iteration
1013 stop = Clamp(stop, 0, axis_size);
1014 }
1015 else
1016 {
1017 // Backward iteration
1018 stop = Clamp(stop, -1, axis_size - 1);
1019 }
1020
1021 return stop;
1022}

References Clamp(), onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::stop_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().