ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::shape_inference Namespace Reference

Namespaces

namespace  bcq
 

Data Structures

struct  StridedSliceParams
 

Typedefs

using Shapes = std::vector< ir::Shape >
 

Functions

ir::Shape inferArgMinMaxShape (const ir::Shape &input_shape, int axis, int rank)
 
ir::Shape inferBatchMatMulShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape, const ir::operation::BatchMatMul::Param &param)
 
ir::Shape inferBCQFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf)
 
ir::Shape inferBCQGatherShape (const ir::Shape &indices_shape, const ir::Shape &cluster_shape, const int32_t *cluster_buf, int rank, const ir::operation::BCQGather::Param &param)
 
ir::Shape inferBCQUnembeddingShape (const ir::Shape &in_shape)
 
ir::Shape inferBroadcastToShape (const ir::Shape shp_shape, const int32_t *shp_buf)
 
ir::Shape inferConcatShape (const Shapes &in_shapes, const ir::operation::Concat::Param &param)
 
ir::Shape inferConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::Conv2D::Param &param)
 
ir::Shape inferDepthwiseConv2DShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, const ir::operation::DepthwiseConv2D::Param &param)
 
ir::Shape inferEltwiseShape (const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
 
ir::Shape inferExpandDimsShape (const ir::Shape &in_shape, int32_t axis)
 
template<typename T >
ir::Shape inferFillShape (const ir::Shape &fill_shape, const T *shape_buf)
 
ir::Shape inferFullyConnectedShape (const ir::Shape &in_shape, const ir::Shape &ker_shape, bool keep_num_dims)
 
ir::Shape inferGatherShape (const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis, int rank)
 
ir::Shape inferOnehotShape (const ir::Shape &input_shape, const int depth, int axis)
 
ir::Shape inferPackShape (const ir::Shape &input_shape, int axis, int rank, int num)
 
ir::Shape inferPadShape (const ir::Shape &in_shape, const int32_t *pad_buf, const size_t num_pads)
 
ir::Shape inferPoolShape (const ir::Shape &in_shape, const ir::operation::Pool2D::Param &param)
 
template<typename T >
ir::Shape inferRangeShape (T start_val, T limit_val, T delta_val)
 
ir::Shape inferReshapeShape (const ir::Shape &input_shape, const int32_t *shape_buf, const int32_t shape_num_elements)
 
ir::Shape inferReduceShape (const ir::Shape &input_shape, const std::vector< int > &axes, bool keep_dims)
 
template<float * >
ir::Shape inferRangeShape (float *start_val, float *limit_val, float *delta_val)
 
ir::Shape inferResizeBilinearShape (const ir::Shape &in_shape, const int32_t output_height, const int32_t output_width)
 
ir::Shape inferSelectShape (const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape, const ir::Shape &input_false_shape)
 
template<typename T >
ir::Shape inferSliceShape (const ir::Shape &input_shape, const T *begins_buf, const T *sizes_buf)
 
ir::Shape inferSpaceToBatchNDShape (const ir::Shape &input_shape, const ir::Shape &block_shape_shape, const ir::Shape &padding_shape, const int32_t *block_shape_buf, const int32_t *padding_buf)
 
ir::Shape inferSplitShape (const ir::Shape input_shape, int axis_value, int num_splits)
 
ir::Shape inferSqueezeShape (const ir::Shape &in_shape, const ir::operation::Squeeze::Param &param)
 
template<typename T >
StridedSliceParams buildStridedSliceParams (const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
ir::Shape inferStridedSliceShape (const ir::Shape &input_shape, const StridedSliceParams &op_params, uint32_t rank)
 
ir::Shape inferTileShape (const ir::Shape &in_shape, const int32_t *multiplier_buf, const int32_t multiplier_size)
 
ir::Shape inferTransposeShape (const ir::Shape &in_shape, const int32_t *perm_buf, const int32_t rank)
 
ir::Shape inferUnpackShape (const ir::Shape &input_shape, int axis, int rank)
 
std::pair< int, int > calcConvLikeHeightAndWidth (const int in_h, const int in_w, const int ker_h, const int ker_w, const ir::Padding pad, const ir::Stride stride, const ir::Dilation dilation={1, 1})
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int32_t *shape_buf)
 
template ir::Shape inferFillShape (const ir::Shape &fill_shape, const int64_t *shape_buf)
 
template ir::Shape inferRangeShape (int start_val, int limit_val, int delta_val)
 
template ir::Shape inferRangeShape (float start_val, float limit_val, float delta_val)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int32_t *begins_buf, const int32_t *sizes_buf)
 
template ir::Shape inferSliceShape (const ir::Shape &input_shape, const int64_t *begins_buf, const int64_t *sizes_buf)
 
template StridedSliceParams buildStridedSliceParams (const uint32_t *begin, const uint32_t *end, const uint32_t *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
 
int Clamp (const int v, const int lo, const int hi)
 
int StartForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis)
 
int StopForAxis (const StridedSliceParams &params, const ir::Shape &input_shape, int axis, int start_for_axis)
 

Typedef Documentation

◆ Shapes

using onert::shape_inference::Shapes = typedef std::vector<ir::Shape>

Definition at line 37 of file ShapeInference.h.

Function Documentation

◆ buildStridedSliceParams() [1/2]

template<typename T >
StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const T *  begin,
const T *  end,
const T *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

Definition at line 879 of file ShapeInference.cc.

882{
883 StridedSliceParams op_params;
884 op_params.start_indices_count = rank;
885 op_params.stop_indices_count = rank;
886 op_params.strides_count = rank;
887
888 for (int i = 0; i < op_params.strides_count; ++i)
889 {
890 op_params.start_indices[i] = begin[i];
891 op_params.stop_indices[i] = end[i];
892 op_params.strides[i] = strides[i];
893
894 assert(op_params.strides[i] != 0);
895 }
896
897 op_params.begin_mask = begin_mask;
898 op_params.ellipsis_mask = 0; // NYI
899 op_params.end_mask = end_mask;
900 op_params.new_axis_mask = 0; // NYI
901 op_params.shrink_axis_mask = shrink_axis_mask;
902
903 assert(sizeof(op_params.begin_mask) * 4 >= rank);
904
905 return op_params;
906}
int32_t begin[5]
Definition Slice.cpp:33

References begin, onert::shape_inference::StridedSliceParams::begin_mask, onert::shape_inference::StridedSliceParams::ellipsis_mask, onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::new_axis_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::start_indices, onert::shape_inference::StridedSliceParams::start_indices_count, onert::shape_inference::StridedSliceParams::stop_indices, onert::shape_inference::StridedSliceParams::stop_indices_count, onert::shape_inference::StridedSliceParams::strides, and onert::shape_inference::StridedSliceParams::strides_count.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ buildStridedSliceParams() [2/2]

template StridedSliceParams onert::shape_inference::buildStridedSliceParams ( const uint32_t *  begin,
const uint32_t *  end,
const uint32_t *  strides,
const uint32_t  begin_mask,
const uint32_t  end_mask,
const uint32_t  shrink_axis_mask,
const uint8_t  rank 
)

◆ calcConvLikeHeightAndWidth()

std::pair< int, int > onert::shape_inference::calcConvLikeHeightAndWidth ( const int  in_h,
const int  in_w,
const int  ker_h,
const int  ker_w,
const ir::Padding  pad,
const ir::Stride  stride,
const ir::Dilation  dilation = {1, 1} 
)

Definition at line 90 of file ShapeInference.cc.

93 {1, 1})
94{
95 int32_t out_h = 0, out_w = 0;
96 int32_t effective_filter_w_size = (ker_w - 1) * dilation.width_factor + 1;
97 int32_t effective_filter_h_size = (ker_h - 1) * dilation.height_factor + 1;
98 switch (pad.type)
99 {
100 case ir::PaddingType::SAME:
101 out_h = ceil_div(in_h, stride.vertical);
102 out_w = ceil_div(in_w, stride.horizontal);
103 break;
104 case ir::PaddingType::VALID:
105 out_h = ceil_div(in_h - effective_filter_h_size + 1, stride.vertical);
106 out_w = ceil_div(in_w - effective_filter_w_size + 1, stride.horizontal);
107 break;
108 case ir::PaddingType::EXPLICIT:
109 out_h =
110 (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
111 out_w =
112 (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + 1;
113 break;
114 default:
115 assert(false);
116 }
117
118 return {out_h, out_w};
119}
PaddingType type
Definition Padding.h:59
ExplicitPadding param
Definition Padding.h:60

Referenced by inferConv2DShape(), inferDepthwiseConv2DShape(), and inferPoolShape().

◆ Clamp()

int onert::shape_inference::Clamp ( const int  v,
const int  lo,
const int  hi 
)

Definition at line 914 of file ShapeInference.cc.

915{
916 assert(!(hi < lo));
917 if (hi < v)
918 return hi;
919 if (v < lo)
920 return lo;
921 return v;
922}

Referenced by StartForAxis(), and StopForAxis().

◆ inferArgMinMaxShape()

ir::Shape onert::shape_inference::inferArgMinMaxShape ( const ir::Shape &  input_shape,
int  axis,
int  rank 
)

Definition at line 126 of file ShapeInference.cc.

127{
128 if (axis < 0 || axis >= rank)
129 {
130 throw std::runtime_error("ArgMinMax shape inference: Wrong axis value " + std::to_string(axis));
131 }
132
133 ir::Shape out_shape;
134 for (int idx = 0; idx < rank; ++idx)
135 {
136 if (idx != axis)
137 {
138 int32_t input_dim = input_shape.dim(idx);
139 out_shape.append(input_dim);
140 }
141 }
142
143 return out_shape;
144}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBatchMatMulShape()

ir::Shape onert::shape_inference::inferBatchMatMulShape ( const ir::Shape &  lhs_shape,
const ir::Shape &  rhs_shape,
const ir::operation::BatchMatMul::Param &  param 
)

Definition at line 228 of file ShapeInference.cc.

230{
231 bool adj_x = param.adj_x;
232 bool adj_y = param.adj_y;
234
235 int output_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
236
237 // Extend lhs and rhs shape
238 ir::Shape extended_lhs_shape(lhs_shape);
239 ir::Shape extended_rhs_shape(rhs_shape);
240 extended_lhs_shape.extendRank(output_rank);
241 extended_rhs_shape.extendRank(output_rank);
242
243 for (int i = 0; i < output_rank - 2; i++)
244 {
245 const int lhs_dim = extended_lhs_shape.dim(i);
246 const int rhs_dim = extended_rhs_shape.dim(i);
247 int broadcast_dim = lhs_dim;
248 if (lhs_dim != rhs_dim)
249 {
250 if (lhs_dim == 1)
251 {
252 broadcast_dim = rhs_dim;
253 }
254 else if (rhs_dim != 1)
255 {
256 throw std::runtime_error{"BatchMatMul shape inference: invalid brodcasting input shape"};
257 }
258 }
259
260 output_shape.append(broadcast_dim);
261 }
262
263 // Fill in the matmul dimensions.
264 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
265 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
266
267 output_shape.append(extended_lhs_shape.dim(lhs_rows_index));
268 output_shape.append(extended_rhs_shape.dim(rhs_cols_index));
269
270 return output_shape;
271}
const luci_interpreter::RuntimeShape output_shape

References onert::ir::operation::BatchMatMul::Param::adj_x, onert::ir::operation::BatchMatMul::Param::adj_y, and output_shape.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBCQFullyConnectedShape()

ir::Shape onert::shape_inference::inferBCQFullyConnectedShape ( const ir::Shape &  in_shape,
const ir::Shape &  cluster_shape,
const int32_t *  cluster_buf 
)

Definition at line 416 of file ShapeInference.cc.

418{
419 assert(cluster_shape.rank() == 2);
420 assert(cluster_shape.dim(1) == 2);
421
422 const auto input_size = in_shape.dim(1);
423 const auto output_size = bcq::getOutputSize(cluster_shape, cluster_buf);
424
425 return {ir::Shape({output_size, input_size})};
426}

References onert::shape_inference::bcq::getOutputSize().

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBCQGatherShape()

ir::Shape onert::shape_inference::inferBCQGatherShape ( const ir::Shape &  indices_shape,
const ir::Shape &  cluster_shape,
const int32_t *  cluster_buf,
int  rank,
const ir::operation::BCQGather::Param &  param 
)

Definition at line 428 of file ShapeInference.cc.

431{
432 ir::Shape out_shape;
433 ir::Shape in_original_shape;
434
435 assert(cluster_shape.rank() == 2);
436 assert(cluster_shape.dim(1) == 2);
437
438 auto hidden_size = param.input_hidden_size;
439 auto axis = param.axis;
440
441 in_original_shape.append(bcq::getOutputSize(cluster_shape, cluster_buf));
442 in_original_shape.append(hidden_size);
443
444 const int indices_rank = indices_shape.rank();
445 for (int idx = 0; idx < rank; ++idx)
446 {
447 if (idx == (int)axis)
448 {
449 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
450 {
451 out_shape.append(indices_shape.dim(indices_idx));
452 }
453 }
454 else
455 {
456 out_shape.append(in_original_shape.dim(idx));
457 }
458 }
459
460 return out_shape;
461}

References onert::ir::operation::BCQGather::Param::axis, onert::shape_inference::bcq::getOutputSize(), and onert::ir::operation::BCQGather::Param::input_hidden_size.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBCQUnembeddingShape()

ir::Shape onert::shape_inference::inferBCQUnembeddingShape ( const ir::Shape &  in_shape)

Definition at line 463 of file ShapeInference.cc.

464{
465 return {ir::Shape({1, in_shape.dim(1)})};
466}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferBroadcastToShape()

ir::Shape onert::shape_inference::inferBroadcastToShape ( const ir::Shape  shp_shape,
const int32_t *  shp_buf 
)

Definition at line 277 of file ShapeInference.cc.

278{
279
280 const int num_elements = shp_shape.num_elements();
281
282 assert(num_elements != 0);
283 assert(shp_buf);
284
285 ir::Shape new_shape(num_elements);
286
287 for (int i = 0; i < num_elements; ++i)
288 {
289 assert(shp_buf[i] != 0); // It shouldn't be 0.
290 new_shape.dim(i) = shp_buf[i];
291 }
292
293 return new_shape;
294}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferConcatShape()

ir::Shape onert::shape_inference::inferConcatShape ( const Shapes &  in_shapes,
const ir::operation::Concat::Param &  param 
)

Definition at line 296 of file ShapeInference.cc.

297{
298 const int32_t concat_axis = param.axis >= 0 ? param.axis : in_shapes[0].rank() + param.axis;
299 const auto &first_in_shape = in_shapes[0];
300
301 // Check that all shapes are equal except for concat axis dimension
302 for (const auto &in_shape : in_shapes)
303 {
304 if (in_shape.rank() != first_in_shape.rank())
305 throw std::runtime_error("Rank in all input tensors should be same");
306
307 for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
308 if (!(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx)))
309 throw std::runtime_error("All tensor should have same dimension "
310 "except dimension on passed axis");
311 }
312
313 // Calculate output shape
314 ir::Shape out_shape(first_in_shape);
315 out_shape.dim(concat_axis) = 0;
316 for (const auto &in_shape : in_shapes)
317 out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
318 return out_shape;
319}

References onert::ir::operation::Concat::Param::axis.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferConv2DShape()

ir::Shape onert::shape_inference::inferConv2DShape ( const ir::Shape &  in_shape,
const ir::Shape &  ker_shape,
const ir::operation::Conv2D::Param &  param 
)

Definition at line 321 of file ShapeInference.cc.

323{
324 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
325 throw std::runtime_error{"Conv2D: stride values must be positive"};
326
327 auto ifm_shape = in_shape.asFeature();
328
329 // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
330 auto kf_shape = ker_shape.asFeature();
331 assert(ifm_shape.C == kf_shape.C);
332
333 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
334 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
335
336 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.N};
337}

References calcConvLikeHeightAndWidth(), onert::ir::operation::Conv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::Conv2D::Param::padding, onert::ir::operation::Conv2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferDepthwiseConv2DShape()

ir::Shape onert::shape_inference::inferDepthwiseConv2DShape ( const ir::Shape &  in_shape,
const ir::Shape &  ker_shape,
const ir::operation::DepthwiseConv2D::Param &  param 
)

Definition at line 339 of file ShapeInference.cc.

341{
342 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
343 throw std::runtime_error{"DepthwiseConv2D: stride values must be positive"};
344
345 auto ifm_shape = in_shape.asFeature();
346
347 // Kernel format is [1, kernel_height, kernel_width, depth_out]
348 auto kf_shape = ker_shape.asFeature();
349 assert(kf_shape.C == static_cast<int32_t>(ifm_shape.C * param.multiplier));
350 assert(kf_shape.N == 1);
351
352 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(
353 ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W, param.padding, param.stride, param.dilation);
354
355 return ir::Shape{ifm_shape.N, out_h, out_w, kf_shape.C};
356}

References calcConvLikeHeightAndWidth(), onert::ir::operation::DepthwiseConv2D::Param::dilation, onert::ir::Stride::horizontal, onert::ir::operation::DepthwiseConv2D::Param::multiplier, onert::ir::operation::DepthwiseConv2D::Param::padding, onert::ir::operation::DepthwiseConv2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferEltwiseShape()

ir::Shape onert::shape_inference::inferEltwiseShape ( const ir::Shape &  lhs_shape,
const ir::Shape &  rhs_shape 
)

Definition at line 121 of file ShapeInference.cc.

122{
123 return broadcastShapes(lhs_shape, rhs_shape);
124}

◆ inferExpandDimsShape()

ir::Shape onert::shape_inference::inferExpandDimsShape ( const ir::Shape &  in_shape,
int32_t  axis 
)

Definition at line 358 of file ShapeInference.cc.

359{
360 ir::Shape out_shape(in_shape.rank() + 1);
361
362 axis = ((axis >= 0) ? axis : /* when axis < 0 */ (out_shape.rank() + axis));
363 if (!(0 <= axis && axis <= in_shape.rank()))
364 throw std::runtime_error("axis of dim is out of range");
365
366 for (int x = 0, out_x = 0; out_x < out_shape.rank(); ++out_x)
367 {
368 if (out_x == axis)
369 out_shape.dim(out_x) = 1;
370 else
371 out_shape.dim(out_x) = in_shape.dim(x++);
372 }
373
374 return out_shape;
375}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferFillShape() [1/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape &  fill_shape,
const int32_t *  shape_buf 
)

◆ inferFillShape() [2/3]

template ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape &  fill_shape,
const int64_t *  shape_buf 
)

◆ inferFillShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferFillShape ( const ir::Shape &  fill_shape,
const T *  shape_buf 
)

Definition at line 377 of file ShapeInference.cc.

378{
379 ir::Shape out_shape(fill_shape.dim(0));
380
381 for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
382 {
383 out_shape.dim(out_x) = static_cast<int32_t>(shape_buf[out_x]);
384 }
385
386 return out_shape;
387}

◆ inferFullyConnectedShape()

ir::Shape onert::shape_inference::inferFullyConnectedShape ( const ir::Shape &  in_shape,
const ir::Shape &  ker_shape,
bool  keep_num_dims 
)

Definition at line 393 of file ShapeInference.cc.

395{
396 assert(in_shape.rank() >= 2);
397 assert(ker_shape.rank() == 2);
398
399 const auto input_size_with_batch = in_shape.num_elements();
400 const auto num_units = ker_shape.dim(0);
401 const auto input_size = ker_shape.dim(1);
402 const auto batch_size = input_size_with_batch / input_size;
403 assert(input_size_with_batch % input_size == 0);
404
405 if (keep_num_dims)
406 {
407 assert(in_shape.dim(in_shape.rank() - 1) == input_size);
408 auto output_shape = ir::Shape(in_shape);
409 output_shape.dim(output_shape.rank() - 1) = num_units;
410 return output_shape;
411 }
412
413 return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
414}

References output_shape.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferGatherShape()

ir::Shape onert::shape_inference::inferGatherShape ( const ir::Shape &  input_shape,
const ir::Shape &  indices_shape,
int  axis,
int  rank 
)

Definition at line 468 of file ShapeInference.cc.

470{
471 ir::Shape out_shape;
472
473 const int indices_rank = indices_shape.rank();
474
475 for (int idx = 0; idx < rank; ++idx)
476 {
477 if (idx == axis)
478 {
479 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
480 {
481 out_shape.append(indices_shape.dim(indices_idx));
482 }
483 }
484 else
485 {
486 out_shape.append(input_shape.dim(idx));
487 }
488 }
489
490 return out_shape;
491}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferOnehotShape()

ir::Shape onert::shape_inference::inferOnehotShape ( const ir::Shape &  input_shape,
const int  depth,
int  axis 
)

Definition at line 493 of file ShapeInference.cc.

494{
495 assert(depth >= 0);
496 const auto rank = input_shape.rank() + 1;
497 ir::Shape newShape(rank);
498
499 axis = (axis == -1) ? (rank - 1) : axis;
500
501 for (int i = 0; i < rank; ++i)
502 {
503 if (i < axis)
504 {
505 newShape.dim(i) = input_shape.dim(i);
506 }
507 else if (i == axis)
508 {
509 newShape.dim(i) = depth;
510 }
511 else
512 {
513 newShape.dim(i) = input_shape.dim(i - 1);
514 }
515 }
516
517 return newShape;
518}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPackShape()

ir::Shape onert::shape_inference::inferPackShape ( const ir::Shape &  input_shape,
int  axis,
int  rank,
int  num 
)

Definition at line 520 of file ShapeInference.cc.

521{
522 ir::Shape out_shape;
523 int in_idx = 0;
524
525 for (int out_idx = 0; out_idx < rank; ++out_idx)
526 {
527 if (out_idx == axis)
528 {
529 out_shape.append(num);
530 }
531 else
532 {
533 out_shape.append(input_shape.dim(in_idx++));
534 }
535 }
536
537 return out_shape;
538}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPadShape()

ir::Shape onert::shape_inference::inferPadShape ( const ir::Shape &  in_shape,
const int32_t *  pad_buf,
const size_t  num_pads 
)

Definition at line 540 of file ShapeInference.cc.

541{
542 assert(num_pads % 2 == 0);
543 const int32_t rank = num_pads / 2;
544
545 ir::Shape ret(rank);
546 for (int32_t i = 0; i < rank; ++i)
547 {
548 const auto before_padding = pad_buf[i * 2];
549 const auto after_padding = pad_buf[i * 2 + 1];
550
551 ret.dim(i) = in_shape.dim(i) + before_padding + after_padding;
552 }
553
554 return ret;
555}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferPoolShape()

ir::Shape onert::shape_inference::inferPoolShape ( const ir::Shape &  in_shape,
const ir::operation::Pool2D::Param &  param 
)

Definition at line 557 of file ShapeInference.cc.

558{
559 if (param.stride.horizontal == 0 || param.stride.vertical == 0)
560 throw std::runtime_error{"Pool2D: stride values must be positive"};
561
562 auto ifm_shape = in_shape.asFeature();
563 const auto [out_h, out_w] = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh,
564 param.kw, param.padding, param.stride);
565 // Pooling don't change number of channels and batch size
566 return ir::Shape{ifm_shape.N, out_h, out_w, ifm_shape.C};
567}

References calcConvLikeHeightAndWidth(), onert::ir::Stride::horizontal, onert::ir::operation::Pool2D::Param::kh, onert::ir::operation::Pool2D::Param::kw, onert::ir::operation::Pool2D::Param::padding, onert::ir::operation::Pool2D::Param::stride, and onert::ir::Stride::vertical.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferRangeShape() [1/4]

template<float * >
ir::Shape onert::shape_inference::inferRangeShape ( float *  start_val,
float *  limit_val,
float *  delta_val 
)

◆ inferRangeShape() [2/4]

template ir::Shape onert::shape_inference::inferRangeShape ( float  start_val,
float  limit_val,
float  delta_val 
)

◆ inferRangeShape() [3/4]

template ir::Shape onert::shape_inference::inferRangeShape ( int  start_val,
int  limit_val,
int  delta_val 
)

◆ inferRangeShape() [4/4]

template<typename T >
ir::Shape onert::shape_inference::inferRangeShape ( T  start_val,
T  limit_val,
T  delta_val 
)

Definition at line 594 of file ShapeInference.cc.

595{
596 ir::Shape out_shape(static_cast<int>(1));
597
598 out_shape.dim(0) =
599 (std::is_integral<T>::value
600 ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
601 : std::ceil(std::abs((start_val - limit_val) / delta_val)));
602 return out_shape;
603}

◆ inferReduceShape()

ir::Shape onert::shape_inference::inferReduceShape ( const ir::Shape &  input_shape,
const std::vector< int > &  axes,
bool  keep_dims 
)

Definition at line 146 of file ShapeInference.cc.

148{
149 int num_axis = axes.size();
150 int input_num_dims = input_shape.rank();
151 if (input_num_dims == 0)
152 {
153 ir::Shape out_shape(0);
154 return out_shape;
155 }
156 if (keep_dims)
157 {
158 ir::Shape out_shape;
159 for (int idx = 0; idx < input_num_dims; ++idx)
160 {
161 bool is_axis = false;
162 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
163 {
164 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
165 {
166 is_axis = true;
167 break;
168 }
169 }
170 if (is_axis)
171 {
172 out_shape.append(1);
173 }
174 else
175 {
176 out_shape.append(input_shape.dim(idx));
177 }
178 }
179 return out_shape;
180 }
181 else
182 {
183 // Calculates size of reducing axis.
184 for (int i = 0; i < num_axis; ++i)
185 {
186 int current = axes[i];
187 if (!(-input_num_dims <= current && current < input_num_dims))
188 throw std::runtime_error{"Invalid dim value " + std::to_string(current)};
189 if (current < 0)
190 {
191 current += input_num_dims;
192 }
193 for (int j = 0; j < i; ++j)
194 {
195 int previous = axes[j];
196 if (previous < 0)
197 {
198 previous += input_num_dims;
199 }
200 if (current == previous)
201 {
202 break;
203 }
204 }
205 }
206 // Determines output dimensions.
207 ir::Shape out_shape;
208 for (int idx = 0; idx < input_num_dims; ++idx)
209 {
210 bool is_axis = false;
211 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
212 {
213 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
214 {
215 is_axis = true;
216 break;
217 }
218 }
219 if (!is_axis)
220 {
221 out_shape.append(input_shape.dim(idx));
222 }
223 }
224 return out_shape;
225 }
226}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferReshapeShape()

ir::Shape onert::shape_inference::inferReshapeShape ( const ir::Shape &  input_shape,
const int32_t *  shape_buf,
const int32_t  shape_num_elements 
)

Definition at line 609 of file ShapeInference.cc.

611{
612 ir::Shape ret(shape_num_elements);
613 int32_t flatten_dim = ir::Shape::kUnspecifiedDim;
614 auto total_num_elements = input_shape.num_elements();
615 for (int32_t i = 0; i < shape_num_elements; ++i)
616 {
617 if (shape_buf[i] < 0)
618 {
619 if (flatten_dim != ir::Shape::kUnspecifiedDim)
620 throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
621 flatten_dim = i;
622 ret.dim(i) = 1;
623 }
624 else
625 {
626 ret.dim(i) = shape_buf[i];
627 }
628 }
629 if (flatten_dim != ir::Shape::kUnspecifiedDim)
630 ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
631
632 // Check reshapable
633 if (total_num_elements != static_cast<size_t>(ret.num_elements()))
634 {
635 // Multi batch case
636 // TODO Handle multi batch case more precisely on runtime level
637 if ((ret.dim(0) == 1) &&
638 (total_num_elements == static_cast<size_t>(ret.num_elements() * input_shape.dim(0))))
639 ret.dim(0) = input_shape.dim(0);
640 else
641 throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
642 }
643
644 return ret;
645}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferResizeBilinearShape()

ir::Shape onert::shape_inference::inferResizeBilinearShape ( const ir::Shape &  in_shape,
const int32_t  output_height,
const int32_t  output_width 
)

Definition at line 569 of file ShapeInference.cc.

571{
572 assert(in_shape.rank() == 4);
573 if (output_height < 0)
574 {
575 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_height = " +
576 std::to_string(output_height)};
577 }
578 if (output_width < 0)
579 {
580 throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_width = " +
581 std::to_string(output_width)};
582 }
583
584 ir::Shape ret(in_shape.rank());
585
586 ret.dim(0) = in_shape.dim(0);
587 ret.dim(1) = output_height;
588 ret.dim(2) = output_width;
589 ret.dim(3) = in_shape.dim(3);
590
591 return ret;
592}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSelectShape()

ir::Shape onert::shape_inference::inferSelectShape ( const ir::Shape &  input_cond_shape,
const ir::Shape &  input_true_shape,
const ir::Shape &  input_false_shape 
)

Definition at line 647 of file ShapeInference.cc.

649{
650 auto haveSameShapes = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
651 const ir::Shape &input_false_shape) {
652 if ((input_cond_shape.rank() != input_true_shape.rank()) ||
653 input_cond_shape.rank() != input_false_shape.rank())
654 {
655 return false;
656 }
657
658 int rank = input_cond_shape.rank();
659 for (int i = 0; i < rank; ++i)
660 {
661 if (input_cond_shape.dim(i) != input_true_shape.dim(i) ||
662 input_cond_shape.dim(i) != input_false_shape.dim(i))
663 {
664 return false;
665 }
666 }
667
668 return true;
669 };
670
671 auto calculateShape = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
672 const ir::Shape &input_false_shape, ir::Shape &new_shape) {
673 ir::Shape cond_shape = input_cond_shape;
674 ir::Shape true_shape = input_true_shape;
675 ir::Shape false_shape = input_false_shape;
676 int most_rank =
677 (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
678 ? cond_shape.rank()
679 : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
680
681 ir::Shape calculate_shape(most_rank);
682
683 cond_shape.extendRank(most_rank);
684 true_shape.extendRank(most_rank);
685 false_shape.extendRank(most_rank);
686
687 for (int i = 0; i < most_rank; ++i)
688 {
689 calculate_shape.dim(i) =
690 (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
691 ? cond_shape.dim(i)
692 : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
693
694 if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
695 (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
696 (false_shape.dim(i) != calculate_shape.dim(i) && false_shape.dim(i) != 1))
697 {
698 return false;
699 }
700 }
701
702 new_shape = calculate_shape;
703
704 return true;
705 };
706
707 bool havesame = haveSameShapes(input_cond_shape, input_true_shape, input_false_shape);
708 if (havesame)
709 {
710 return input_cond_shape;
711 }
712
713 ir::Shape new_shape;
714 bool possible = calculateShape(input_cond_shape, input_true_shape, input_false_shape, new_shape);
715
716 if (!possible)
717 {
718 throw std::runtime_error("Broadcasting is not possible.");
719 }
720
721 return new_shape;
722}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSliceShape() [1/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape &  input_shape,
const int32_t *  begins_buf,
const int32_t *  sizes_buf 
)

◆ inferSliceShape() [2/3]

template ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape &  input_shape,
const int64_t *  begins_buf,
const int64_t *  sizes_buf 
)

◆ inferSliceShape() [3/3]

template<typename T >
ir::Shape onert::shape_inference::inferSliceShape ( const ir::Shape &  input_shape,
const T *  begins_buf,
const T *  sizes_buf 
)

Definition at line 725 of file ShapeInference.cc.

726{
727 const uint32_t rank = input_shape.rank();
728 ir::Shape out_shape(rank);
729
730 for (uint32_t idx = 0; idx < rank; ++idx)
731 {
732 const auto input_dim = input_shape.dim(idx);
733
734 // begin is zero-based
735 auto begin = begins_buf[idx];
736 if (begin < 0)
737 throw std::runtime_error("shape inference Slice: Invalid begin.");
738
739 // size is one-based
740 auto size = sizes_buf[idx];
741 if (size < -1)
742 throw std::runtime_error("shape inference Slice: Invalid size.");
743
744 if (size == -1)
745 {
746 size = input_dim - begin;
747 }
748 else
749 {
750 if (input_dim < static_cast<int32_t>(begin + size))
751 throw std::runtime_error("shape inference Slice: Invalid begin and size.");
752 }
753 out_shape.dim(idx) = static_cast<int32_t>(size);
754 }
755
756 return out_shape;
757}
int32_t size[5]
Definition Slice.cpp:35

References begin, and size.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSpaceToBatchNDShape()

ir::Shape onert::shape_inference::inferSpaceToBatchNDShape ( const ir::Shape &  input_shape,
const ir::Shape &  block_shape_shape,
const ir::Shape &  padding_shape,
const int32_t *  block_shape_buf,
const int32_t *  padding_buf 
)

Definition at line 764 of file ShapeInference.cc.

768{
769 const uint32_t rank = input_shape.rank();
770 ir::Shape out_shape(rank);
771
772 // Currently, only 4D NHWC input/output op_context are supported.
773 // The 4D array need to have exactly 2 spatial dimensions.
774 // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
775 [[maybe_unused]] const int32_t kInputDimensionNum = 4;
776 [[maybe_unused]] const int32_t kBlockSizeDimensionNum = 1;
777 const int32_t kSpatialDimensionNum = 2;
778
779 assert(block_shape_shape.rank() == kBlockSizeDimensionNum);
780 assert(block_shape_shape.dim(0) == kSpatialDimensionNum);
781 assert(padding_shape.dim(0) == kSpatialDimensionNum);
782 assert(padding_shape.dim(1) == 2); // fixed, meaning left/right padding for each element
783 assert(padding_shape.rank() == 2); // fixed, meaning dimension(dim 0) and padding length(dim 1)
784
785 // Ensures the input height and width (with padding) is a multiple of block
786 // shape height and width.
787 for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
788 {
789 int final_dim_size =
790 (input_shape.dim(dim + 1) + padding_buf[dim * 2] + padding_buf[dim * 2 + 1]);
791
792 assert(final_dim_size % block_shape_buf[dim] == 0);
793
794 out_shape.dim(dim + 1) = final_dim_size / block_shape_buf[dim];
795 }
796
797 const int output_batch_size = input_shape.dim(0) * block_shape_buf[0] * block_shape_buf[1];
798 const int output_channel_size = input_shape.dim(3);
799
800 out_shape.dim(0) = output_batch_size;
801 out_shape.dim(3) = output_channel_size;
802
803 return out_shape;
804}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSplitShape()

ir::Shape onert::shape_inference::inferSplitShape ( const ir::Shape  input_shape,
int  axis_value,
int  num_splits 
)

Definition at line 806 of file ShapeInference.cc.

807{
808 ir::Shape newShape(input_shape);
809
810 assert(axis_value >= 0);
811 assert(axis_value < input_shape.rank());
812
813 const int input_size = input_shape.dim(axis_value);
814 assert(input_size % num_splits == 0);
815 const int slice_size = input_size / num_splits;
816
817 newShape.dim(axis_value) = slice_size;
818
819 return newShape;
820}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferSqueezeShape()

ir::Shape onert::shape_inference::inferSqueezeShape ( const ir::Shape &  in_shape,
const ir::operation::Squeeze::Param &  param 
)

Definition at line 822 of file ShapeInference.cc.

823{
824 const int ndims = param.ndim;
825 const int *squeeze_dims = param.dims;
826 bool should_squeeze[8] = {false};
827 int num_squeezed_dims = 0;
828 int shape_rank = in_shape.rank();
829 if (ndims == 0)
830 {
831 for (int idx = 0; idx < shape_rank; ++idx)
832 {
833 if (in_shape.dim(idx) == 1)
834 {
835 should_squeeze[idx] = true;
836 ++num_squeezed_dims;
837 }
838 }
839 }
840 else
841 {
842 for (int idx = 0; idx < ndims; ++idx)
843 {
844 int current = squeeze_dims[idx];
845 if (current < 0)
846 {
847 current += shape_rank;
848 }
849
850 if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
851 {
852 throw std::runtime_error(
853 "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
854 }
855
856 if (!should_squeeze[current])
857 {
858 ++num_squeezed_dims;
859 }
860 should_squeeze[current] = true;
861 }
862 }
863
864 // Set output shape.
865 ir::Shape out_shape(shape_rank - num_squeezed_dims);
866 for (int in_idx = 0, out_idx = 0; in_idx < shape_rank; ++in_idx)
867 {
868 if (!should_squeeze[in_idx])
869 {
870 out_shape.dim(out_idx++) = in_shape.dim(in_idx);
871 }
872 }
873
874 return out_shape;
875}

References onert::ir::operation::Squeeze::Param::dims, and onert::ir::operation::Squeeze::Param::ndim.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferStridedSliceShape()

ir::Shape onert::shape_inference::inferStridedSliceShape ( const ir::Shape &  input_shape,
const StridedSliceParams &  op_params,
uint32_t  rank 
)

Definition at line 1029 of file ShapeInference.cc.

1031{
1032 ir::Shape out_shape;
1033
1034 for (uint32_t idx = 0; idx < rank; ++idx)
1035 {
1036 int32_t stride = op_params.strides[idx];
1037 int32_t begin = StartForAxis(op_params, input_shape, idx);
1038 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
1039
1040 // When shrinking an axis, the end position does not matter (and can be
1041 // incorrect when negative indexing is used, see Issue #19260). Always use
1042 // begin + 1 to generate a length 1 slice, since begin has
1043 // already been adjusted for negative indices by StartForAxis.
1044 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
1045 if (shrink_axis)
1046 {
1047 end = begin + 1;
1048 }
1049
1050 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
1051 dim_shape = dim_shape < 0 ? 0 : dim_shape;
1052 if (!shrink_axis)
1053 {
1054 out_shape.append(dim_shape);
1055 }
1056 }
1057
1058 return out_shape;
1059}

References begin, onert::shape_inference::StridedSliceParams::shrink_axis_mask, StartForAxis(), StopForAxis(), and onert::shape_inference::StridedSliceParams::strides.

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferTileShape()

ir::Shape onert::shape_inference::inferTileShape ( const ir::Shape &  in_shape,
const int32_t *  multiplier_buf,
const int32_t  multiplier_size 
)

Definition at line 1061 of file ShapeInference.cc.

1063{
1064 if (multiplier_size != in_shape.rank())
1065 {
1066 throw std::runtime_error(
1067 "inferTileShape failed, input rank: " + std::to_string(in_shape.rank()) +
1068 ", bad multipliers size: " + std::to_string(multiplier_size) + "");
1069 }
1070 ir::Shape new_Shape(in_shape.rank());
1071
1072 for (int i = 0; i < in_shape.rank(); ++i)
1073 {
1074 assert(multiplier_buf[i]); // multiplier_buf[i] shuld not be 0.
1075 new_Shape.dim(i) = in_shape.dim(i) * multiplier_buf[i];
1076 }
1077 return new_Shape;
1078}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferTransposeShape()

ir::Shape onert::shape_inference::inferTransposeShape ( const ir::Shape &  in_shape,
const int32_t *  perm_buf,
const int32_t  rank 
)

Definition at line 1080 of file ShapeInference.cc.

1082{
1083 const auto rank = in_shape.rank();
1084 if (perm_size > rank)
1085 {
1086 throw std::runtime_error("inferTransposeShape failed, bad permutation size: " +
1087 std::to_string(perm_size));
1088 }
1089
1090 const int32_t *perm_data = perm_buf;
1091 std::vector<int32_t> regular_perm_vec;
1092 if (perm_size == 0)
1093 {
1094 // perm_data will be set to (n-1...0)
1095 regular_perm_vec.resize(rank);
1096 std::iota(regular_perm_vec.begin(), regular_perm_vec.end(), 0);
1097 std::reverse(regular_perm_vec.begin(), regular_perm_vec.end());
1098 perm_data = regular_perm_vec.data();
1099 }
1100 else
1101 {
1102 assert(rank == perm_size);
1103 }
1104
1105 ir::Shape out_shape(rank);
1106 std::vector<bool> visit_perms(rank, false);
1107 for (int idx = 0; idx < rank; idx++)
1108 {
1109 const auto perm_val = perm_data[idx];
1110 // Check invalid permutation value
1111 if (perm_val < 0 || perm_val >= rank)
1112 {
1113 throw std::runtime_error("inferTransposeShape failed, bad permutation value: " +
1114 std::to_string(perm_val));
1115 }
1116
1117 // Check duplicated permutation value
1118 if (visit_perms.at(perm_val))
1119 {
1120 throw std::runtime_error("inferTransposeShape failed, duplicated permutation value: " +
1121 std::to_string(perm_val));
1122 }
1123 visit_perms.at(perm_val) = true;
1124
1125 out_shape.dim(idx) = in_shape.dim(perm_val);
1126 }
1127 return out_shape;
1128}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ inferUnpackShape()

ir::Shape onert::shape_inference::inferUnpackShape ( const ir::Shape &  input_shape,
int  axis,
int  rank 
)

Definition at line 1130 of file ShapeInference.cc.

1131{
1132 ir::Shape out_shape;
1133
1134 for (int out_idx = 0; out_idx < rank; out_idx++)
1135 {
1136 if (out_idx != axis)
1137 {
1138 out_shape.append(input_shape.dim(out_idx));
1139 }
1140 }
1141
1142 return out_shape;
1143}

Referenced by onert::exec::DynamicShapeInferer::visit().

◆ StartForAxis()

int onert::shape_inference::StartForAxis ( const StridedSliceParams &  params,
const ir::Shape &  input_shape,
int  axis 
)

Definition at line 924 of file ShapeInference.cc.

925{
926 const auto begin_mask = params.begin_mask;
927 const auto *start_indices = params.start_indices;
928 const auto *strides = params.strides;
929 // Begin with the specified index.
930 int start = start_indices[axis];
931
932 // begin_mask override
933 if (begin_mask & 1 << axis)
934 {
935 if (strides[axis] > 0)
936 {
937 // Forward iteration - use the first element. These values will get
938 // clamped below (Note: We could have set them to 0 and axis_size-1, but
939 // use lowest() and max() to maintain symmetry with StopForAxis())
940 start = std::numeric_limits<int>::lowest();
941 }
942 else
943 {
944 // Backward iteration - use the last element.
945 start = std::numeric_limits<int>::max();
946 }
947 }
948
949 // Handle negative indices
950 int axis_size = input_shape.dim(axis);
951 if (start < 0)
952 {
953 start += axis_size;
954 }
955
956 // Clamping
957 start = Clamp(start, 0, axis_size - 1);
958
959 return start;
960}
int Clamp(const int32_t v, const int32_t lo, const int32_t hi)

References onert::shape_inference::StridedSliceParams::begin_mask, Clamp(), onert::shape_inference::StridedSliceParams::start_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().

◆ StopForAxis()

int onert::shape_inference::StopForAxis ( const StridedSliceParams &  params,
const ir::Shape &  input_shape,
int  axis,
int  start_for_axis 
)

Definition at line 967 of file ShapeInference.cc.

969{
970 const auto end_mask = params.end_mask;
971 const auto shrink_axis_mask = params.shrink_axis_mask;
972 const auto *stop_indices = params.stop_indices;
973 const auto *strides = params.strides;
974
975 // Begin with the specified index
976 const bool shrink_axis = shrink_axis_mask & (1 << axis);
977 int stop = stop_indices[axis];
978
979 // When shrinking an axis, the end position does not matter (and can be
980 // incorrect when negative indexing is used, see Issue #19260). Always use
981 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
982 // already been adjusted for negative indices.
983 if (shrink_axis)
984 {
985 stop = start_for_axis + 1;
986 }
987
988 // end_mask override
989 if (end_mask & (1 << axis))
990 {
991 if (strides[axis] > 0)
992 {
993 // Forward iteration - use the last element. These values will get
994 // clamped below
995 stop = std::numeric_limits<int>::max();
996 }
997 else
998 {
999 // Backward iteration - use the first element.
1000 stop = std::numeric_limits<int>::lowest();
1001 }
1002 }
1003
1004 // Handle negative indices
1005
1006 const int axis_size = input_shape.dim(axis);
1007 if (stop < 0)
1008 {
1009 stop += axis_size;
1010 }
1011
1012 // Clamping
1013 // Because the end index points one past the last element, we need slightly
1014 // different clamping ranges depending on the direction.
1015 if (strides[axis] > 0)
1016 {
1017 // Forward iteration
1018 stop = Clamp(stop, 0, axis_size);
1019 }
1020 else
1021 {
1022 // Backward iteration
1023 stop = Clamp(stop, -1, axis_size - 1);
1024 }
1025
1026 return stop;
1027}

References Clamp(), onert::shape_inference::StridedSliceParams::end_mask, onert::shape_inference::StridedSliceParams::shrink_axis_mask, onert::shape_inference::StridedSliceParams::stop_indices, and onert::shape_inference::StridedSliceParams::strides.

Referenced by inferStridedSliceShape().