ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::compiler::ShapeValidator Class Reference

#include <ShapeValidator.h>

Collaboration diagram for onert::compiler::ShapeValidator:

Public Member Functions

 ShapeValidator (void)=delete
 
 ShapeValidator (const ir::Graph &graph)
 
 ShapeValidator (const ShapeValidator &)=delete
 
 ShapeValidator (ShapeValidator &&)=delete
 
 ~ShapeValidator ()=default
 
ShapeValidator & operator= (const ShapeValidator &)=delete
 
ShapeValidator & operator= (ShapeValidator &&)=delete
 
void operator() ()
 
void visit (const ir::operation::BatchMatMul &node) override
 
void visit (const ir::operation::BatchToSpaceND &node) override
 
void visit (const ir::operation::BCQFullyConnected &node) override
 
void visit (const ir::operation::BCQGather &node) override
 
void visit (const ir::operation::Conv2D &node) override
 
void visit (const ir::operation::Comparison &node) override
 
void visit (const ir::operation::DepthwiseConv2D &node) override
 
void visit (const ir::operation::FullyConnected &node) override
 
void visit (const ir::operation::Softmax &node) override
 
void visit (const ir::operation::InstanceNorm &node) override
 
void visit (const ir::operation::Permute &node) override
 
void visit (const ir::operation::Pool2D &node) override
 
void visit (const ir::operation::Reduce &node) override
 
void visit (const ir::operation::Transpose &node) override
 
void visit (const ir::operation::RNN &node) override
 
void visit (const ir::operation::SpaceToBatchND &node) override
 
void visit (const ir::operation::SpaceToDepth &node) override
 
void visit (const ir::operation::ElementwiseActivation &node) override
 
void visit (const ir::operation::ElementwiseBinary &node) override
 
void visit (const ir::operation::ElementwiseUnary &node) override
 
void visit (const ir::operation::EmbeddingLookup &node) override
 
void visit (const ir::operation::ExpandDims &node) override
 
void visit (const ir::operation::HashtableLookup &node) override
 
void visit (const ir::operation::TransposeConv &node) override
 
void visit (const ir::operation::Gather &node) override
 
void visit (const ir::operation::DepthToSpace &node) override
 
void visit (const ir::operation::Pack &node) override
 
void visit (const ir::operation::LSTM &node) override
 
void visit (const ir::operation::L2Normalization &node) override
 
void visit (const ir::operation::Unpack &node) override
 
void visit (const ir::operation::Pad &node) override
 
void visit (const ir::operation::Select &node) override
 
void visit (const ir::operation::StridedSlice &node) override
 
void visit (const ir::operation::Split &node) override
 
void visit (const ir::operation::Shape &node) override
 
void visit (const ir::operation::ResizeBilinear &node) override
 
void visit (const ir::operation::Reverse &node) override
 
void visit (const ir::operation::If &node) override
 
void visit (const ir::operation::While &node) override
 
void visit (const ir::operation::SquaredDifference &node) override
 
void visit (const ir::operation::Tile &node) override
 
void visit (const ir::operation::Range &node) override
 
void visit (const ir::operation::MatrixBandPart &node) override
 
void visit (const ir::operation::LogSoftmax &node) override
 
void visit (const ir::operation::RmsNorm &node) override
 
void visit (const ir::operation::RoPE &node) override
 
- Public Member Functions inherited from onert::ir::OperationVisitor
virtual ~OperationVisitor ()=default
 

Detailed Description

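ShapeValidator walks every operation of an ir::Graph (see operator()) and, through the per-operation visit() overloads, checks operand ranks and shapes with OP_REQUIRES. Most overloads skip their checks when the operation's output operand has a dynamic shape; some overloads (e.g. Comparison, ElementwiseBinary, If) perform no validation yet.

Definition at line 32 of file ShapeValidator.h.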

Constructor & Destructor Documentation

◆ ShapeValidator() [1/4]

onert::compiler::ShapeValidator::ShapeValidator ( void  )
delete

◆ ShapeValidator() [2/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ir::Graph & graph)

Definition at line 34 of file ShapeValidator.cc.

34: _graph{graph} {}

◆ ShapeValidator() [3/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ShapeValidator & )
delete

◆ ShapeValidator() [4/4]

onert::compiler::ShapeValidator::ShapeValidator ( ShapeValidator &&  )
delete

◆ ~ShapeValidator()

onert::compiler::ShapeValidator::~ShapeValidator ( )
default

Member Function Documentation

◆ operator()()

void onert::compiler::ShapeValidator::operator() ( )

Definition at line 49 of file ShapeValidator.cc.

50{
51 _graph.operations().iterate(
52 [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); });
53}
const Operations & operations() const override
Definition Graph.h:112
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
::onert::util::Index< uint32_t, OperationIndexTag > OperationIndex
Definition Index.h:30

References onert::ir::IOperation::accept(), onert::util::ObjectManager< Index, Object >::iterate(), and onert::ir::Graph::operations().

◆ operator=() [1/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( const ShapeValidator & )
delete

◆ operator=() [2/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( ShapeValidator &&  )
delete

◆ visit() [1/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchMatMul & node)
override

Definition at line 55 of file ShapeValidator.cc.

56{
57 const auto &operands = _graph.operands();
58 const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
59 const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
60 const auto out_index{node.getOutputs().at(0)};
61
62 if (operands.at(out_index).info().isDynamic())
63 return;
64
65 OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4);
66 OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4);
67 OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2);
68 OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2);
69}
#define OP_REQUIRES(EXP)
const Operands & operands() const override
Definition Graph.h:110

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchMatMul::LHS, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::BatchMatMul::RHS.

◆ visit() [2/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchToSpaceND & node)
override

Definition at line 71 of file ShapeValidator.cc.

72{
73 const auto &operands = _graph.operands();
74 const auto ofm_index{node.getOutputs().at(0)};
75 if (operands.at(ofm_index).info().isDynamic())
76 return;
77
78 const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
79 const auto block_size_index{
80 node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
81
82 const auto input_shape = operands.at(ifm_index).shape().asFeature();
83 const auto output_shape = operands.at(ofm_index).shape().asFeature();
84
85 // All requirement as per NNAPI specification.
86 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
87 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
88 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
89
90 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
91
92 if (node.getInputs().size() != 2)
93 {
94 const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
95 OP_REQUIRES(operands.at(crops_index).shape().rank() == 2);
96 OP_REQUIRES(operands.at(crops_index).shape().dim(0) ==
97 (operands.at(ifm_index).shape().rank() - 2));
98 OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2);
99 }
100
101 OP_REQUIRES(input_shape.C == output_shape.C);
102}
const Object & at(const Index &index) const
Get the object that is associated with the given index.
const luci_interpreter::RuntimeShape output_shape

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BatchToSpaceND::BLOCK_SIZE, onert::ir::operation::BatchToSpaceND::CROPS_DATA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchToSpaceND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::OperandIndexSequence::size().

◆ visit() [3/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQFullyConnected & node)
override

Definition at line 104 of file ShapeValidator.cc.

105{
106 const auto &operands = _graph.operands();
107 const auto ofm_index{node.getOutputs().at(0)};
108 if (operands.at(ofm_index).info().isDynamic())
109 return;
110
111 const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
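112 const auto weight_scales_index{
113 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
114 const auto weight_binary_index{
115 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
116 const auto weight_cluster_index{
117 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};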
118 const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
119
120 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
121 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
122 OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
123 OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
124 OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
125
126 OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
127
128 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
129 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
130
131 // more shape validation will be done inside kernel.
132
133 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
134}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BCQFullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQFullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::BCQFullyConnected::WEIGHTS_BINARY, onert::ir::operation::BCQFullyConnected::WEIGHTS_CLUSTERS, and onert::ir::operation::BCQFullyConnected::WEIGHTS_SCALES.

◆ visit() [4/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQGather & node)
override

Definition at line 136 of file ShapeValidator.cc.

137{
138 const auto &operands = _graph.operands();
139 const auto ofm_index{node.getOutputs().at(0)};
140 if (operands.at(ofm_index).info().isDynamic())
141 return;
142
143 const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
144 const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
145 const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
146 const auto input_clusters_index{
147 node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
148
149 OP_REQUIRES(operands.at(indices_index).shape().rank() <=
150 2); // TODO : support rank up to 4 or more
151 OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2);
152 OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1);
153 OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2);
154
155 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0);
156 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2);
157
158 // more shape validation will be done inside kernel.
159}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQGather::INDICES, onert::ir::operation::BCQGather::INPUT_BINARY, onert::ir::operation::BCQGather::INPUT_CLUSTERS, onert::ir::operation::BCQGather::INPUT_SCALES, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [5/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Comparison & node)
override

Definition at line 178 of file ShapeValidator.cc.

179{
180 // TODO Shape validation of comparison
181}

◆ visit() [6/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Conv2D & node)
override

Definition at line 161 of file ShapeValidator.cc.

162{
163 const auto &operands = _graph.operands();
164 const auto ofm_index{node.getOutputs().at(0)};
165 if (operands.at(ofm_index).info().isDynamic())
166 return;
167
168 const auto ifm_index{node.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
169 const auto ker_index{node.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
170 const auto bias_index{node.getInputs().at(ir::operation::Conv2D::Input::BIAS)};
171
172 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
173 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
174 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
175 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
176}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Conv2D::INPUT, onert::ir::operation::Conv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [7/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthToSpace & node)
override

Definition at line 570 of file ShapeValidator.cc.

571{
572 const auto &operands = _graph.operands();
573 int32_t block_size = node.param().block_size;
574
575 // shape check
576 const auto output_index{node.getOutputs().at(0)};
577 if (operands.at(output_index).info().isDynamic())
578 return;
579
580 const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
581
582 const auto output_shape = operands.at(output_index).shape().asFeature();
583 const auto input_shape = operands.at(input_index).shape().asFeature();
584
585 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
586 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
587
588 {
589 OP_REQUIRES(output_shape.N == input_shape.N);
590 OP_REQUIRES(output_shape.H == input_shape.H * block_size);
591 OP_REQUIRES(output_shape.W == input_shape.W * block_size);
592 OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
593 OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
594 }
595}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthToSpace::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthToSpace::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::DepthToSpace::param().

◆ visit() [8/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthwiseConv2D & node)
override

Definition at line 183 of file ShapeValidator.cc.

184{
185 const auto &operands = _graph.operands();
186 const auto ofm_index{node.getOutputs().at(0)};
187 if (operands.at(ofm_index).info().isDynamic())
188 return;
189
190 const auto ifm_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)};
191 const auto ker_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)};
192 const auto bias_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::BIAS)};
193
194 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
195 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
196 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
197 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
198}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthwiseConv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthwiseConv2D::INPUT, onert::ir::operation::DepthwiseConv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [9/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseActivation & node)
override

Definition at line 428 of file ShapeValidator.cc.

428{ checkUnaryOp(node); }

◆ visit() [10/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseBinary & node)
override

Definition at line 430 of file ShapeValidator.cc.

431{
432 // TODO Shape validation of ElementwiseBinary
433}

◆ visit() [11/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseUnary & node)
override

Definition at line 435 of file ShapeValidator.cc.

436{
437 const auto &operands = _graph.operands();
438 const auto output_index{node.getOutputs().at(0)};
439 const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
440
441 if (operands.at(output_index).info().isDynamic())
442 return;
443
444 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
445}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseUnary::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [12/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::EmbeddingLookup & node)
override

Definition at line 447 of file ShapeValidator.cc.

448{
449 const auto &operands = _graph.operands();
450 const auto output_index{node.getOutputs().at(0)};
451 const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
452 const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
453
454 const auto &output_obj = operands.at(output_index);
455 const auto &lookups_obj = operands.at(lookups_index);
456 const auto &values_obj = operands.at(values_index);
457
458 // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
459 // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
460 {
461 if (operands.at(output_index).info().isDynamic())
462 return;
463
464 const auto &output_shape = output_obj.shape();
465 const auto &lookups_shape = lookups_obj.shape();
466 const auto &values_shape = values_obj.shape();
467
468 OP_REQUIRES(lookups_shape.rank() == 1);
469 OP_REQUIRES(values_shape.rank() >= 2);
470
471 // output should be a n-D tensor with the same rank and shape as the values tensor, except for
472 // the first dimension which has the same size as lookups' only dimension.
473 OP_REQUIRES(output_shape.rank() == values_shape.rank());
474 OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
475 for (int n = 1; n < output_shape.rank(); ++n)
476 {
477 OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
478 }
479 }
480}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::EmbeddingLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::EmbeddingLookup::VALUES.

◆ visit() [13/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ExpandDims & node)
override

Definition at line 482 of file ShapeValidator.cc.

483{
484 const auto &operands = _graph.operands();
485 const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
486
487 if (operands.at(axis_index).info().isDynamic())
488 return;
489 OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1);
490}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ExpandDims::AXIS, onert::ir::Operation::getInputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [14/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::FullyConnected & node)
override

Definition at line 200 of file ShapeValidator.cc.

201{
202 const auto &operands = _graph.operands();
203 const auto ofm_index{node.getOutputs().at(0)};
204 if (operands.at(ofm_index).info().isDynamic())
205 return;
206
207 const auto ifm_index{node.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
208 const auto ker_index{node.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
209 const auto bias_index{node.getInputs().at(ir::operation::FullyConnected::Input::BIAS)};
210
211 OP_REQUIRES(operands.at(ifm_index).shape().rank() >= 2);
212 OP_REQUIRES(operands.at(ker_index).shape().rank() == 2);
213 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
214}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::FullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::FullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::FullyConnected::WEIGHT.

◆ visit() [15/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Gather & node)
override

Definition at line 552 of file ShapeValidator.cc.

553{
554 const auto &operands = _graph.operands();
555 const auto ofm_index{node.getOutputs().at(0)};
556 if (operands.at(ofm_index).info().isDynamic())
557 return;
558
559 const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
560 const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
561
562 const auto &ifm_shape = operands.at(ifm_index).shape();
563 const auto &indices_shape = operands.at(indices_index).shape();
564 const auto &ofm_shape = operands.at(ofm_index).shape();
565
566 // Since gather implementation is general enough, we do not restrict max rank
567 OP_REQUIRES(ifm_shape.rank() + indices_shape.rank() - 1 == ofm_shape.rank());
568}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Gather::INDICES, onert::ir::operation::Gather::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [16/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::HashtableLookup & node)
override

Definition at line 492 of file ShapeValidator.cc.

493{
494 const auto &operands = _graph.operands();
495 const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
496 const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
497 const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
498 const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
499
500 const auto &output_obj = operands.at(output_index);
501 const auto &lookups_obj = operands.at(lookups_index);
502 const auto &keys_obj = operands.at(keys_index);
503 const auto &values_obj = operands.at(values_index);
504
505 if (operands.at(output_index).info().isDynamic())
506 return;
507
508 const auto &output_shape = output_obj.shape();
509 const auto &lookups_shape = lookups_obj.shape();
510 const auto &keys_shape = keys_obj.shape();
511 const auto &values_shape = values_obj.shape();
512
513 OP_REQUIRES(values_shape.rank() == output_shape.rank());
514 OP_REQUIRES(lookups_shape.rank() == 1);
515 OP_REQUIRES(keys_shape.rank() == 1);
516 OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
517 OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
518}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::HashtableLookup::KEYS, onert::ir::operation::HashtableLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::HashtableLookup::OUTPUT, output_shape, and onert::ir::operation::HashtableLookup::VALUES.

◆ visit() [17/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::If & node)
override

Definition at line 1003 of file ShapeValidator.cc.

1004{
1005 // TODO Add to validate with subgraphs
1006}

◆ visit() [18/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::InstanceNorm & node)
override

Definition at line 228 of file ShapeValidator.cc.

229{
230 const auto &operands = _graph.operands();
231 const auto ofm_index{node.getOutputs().at(0)};
232 if (operands.at(ofm_index).info().isDynamic())
233 return;
234
235 const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
236 const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
237 const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
238
239 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
240 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
241 OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1);
242 OP_REQUIRES(operands.at(beta_index).shape().rank() == 1);
243}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::InstanceNorm::BETA, onert::ir::operation::InstanceNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::InstanceNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [19/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::L2Normalization & node)
override

Definition at line 874 of file ShapeValidator.cc.

875{
876 const auto &operands = _graph.operands();
877 const auto ofm_index{node.getOutputs().at(0)};
878 if (operands.at(ofm_index).info().isDynamic())
879 return;
880
881 const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
882
883 auto ifm_shape = operands.at(ifm_index).shape();
884 auto ofm_shape = operands.at(ofm_index).shape();
885
886 OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
887
888 for (auto i = 0; i < ifm_shape.rank(); i++)
889 {
890 OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
891 }
892}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::L2Normalization::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [20/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LogSoftmax & node)
override

Definition at line 1109 of file ShapeValidator.cc.

1110{
1111 const auto &operands = _graph.operands();
1112 const auto output_index{node.getOutputs().at(0)};
1113 if (operands.at(output_index).info().isDynamic())
1114 return;
1115
1116 const auto input_index{node.getInputs().at(0)};
1117
1118 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
1119}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [21/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LSTM & node)
override

Definition at line 619 of file ShapeValidator.cc.

620{
621 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
622 // TODO Support dynamic rnn
623 const auto &operands = _graph.operands();
624 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
625 if (operands.at(output_index).info().isDynamic())
626 return;
627
628 const auto scratch_buffer_index{
629 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
630 const auto output_state_out_index{
631 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
632 const auto cell_state_out_index{
633 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
634
635 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
636 const auto input_to_input_weights_index{
637 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
638 const auto input_to_forget_weights_index{
640 const auto input_to_cell_weights_index{
642 const auto input_to_output_weights_index{
644 const auto recurrent_to_input_weights_index{
645 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
646 const auto recurrent_to_forget_weights_index{
648 const auto recurrent_to_cell_weights_index{
650 const auto recurrent_to_output_weights_index{
652 const auto cell_to_input_weights_index{
653 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
654 const auto cell_to_forget_weights_index{
655 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
656 const auto cell_to_output_weights_index{
657 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
658 const auto input_gate_bias_index{
659 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
660 const auto forget_gate_bias_index{
662 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
663 const auto output_gate_bias_index{
665 const auto projection_weights_index{
666 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
667 const auto projection_bias_index{
668 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
669 const auto output_state_in_index{
671 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
672
673 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
674 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
675 {
676 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
677 operands.at(output_index).shape().dim(i));
678 }
679 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
680 operands.at(output_index).shape().rank() == 3) &&
681 (operands.at(input_index).shape().rank() == 2 ||
682 operands.at(input_index).shape().rank() == 3) &&
683 (!operands.exist(input_to_input_weights_index) ||
684 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
685 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
686 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
687 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
688 (!operands.exist(recurrent_to_input_weights_index) ||
689 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
690 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
691 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
692 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
693 (!operands.exist(projection_weights_index) ||
694 operands.at(projection_weights_index).shape().rank() == 2) &&
695 operands.at(output_state_in_index).shape().rank() == 2 &&
696 operands.at(cell_state_in_index).shape().rank() == 2);
697
698 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
699 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
700 (!operands.exist(cell_to_forget_weights_index) ||
701 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
702 (!operands.exist(cell_to_output_weights_index) ||
703 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
704 (!operands.exist(input_gate_bias_index) ||
705 operands.at(input_gate_bias_index).shape().rank() == 1) &&
706 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
707 operands.at(cell_bias_index).shape().rank() == 1 &&
708 operands.at(output_gate_bias_index).shape().rank() == 1 &&
709 (!operands.exist(projection_bias_index) ||
710 operands.at(projection_bias_index).shape().rank() == 1));
711
712 // CIFG assertion
713 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
714 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
715 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
716 (!operands.exist(recurrent_to_input_weights_index) ||
717 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
718 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
719 (!operands.exist(input_gate_bias_index) ||
720 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
721 (!operands.exist(cell_to_input_weights_index) ||
722 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
723 ((operands.exist(input_to_input_weights_index) &&
724 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
725 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
726 (operands.exist(recurrent_to_input_weights_index) &&
727 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
728 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
729 (operands.exist(input_gate_bias_index) &&
730 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
731
732 // Peephole assertion
733 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
734 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
735 (!operands.exist(cell_to_output_weights_index) ||
736 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
737 ((operands.exist(cell_to_forget_weights_index) &&
738 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
739 (operands.exist(cell_to_output_weights_index) &&
740 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
741
742 bool has_input_to_input_weights =
743 operands.exist(input_to_input_weights_index) &&
744 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
745 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
746 bool has_recurrent_to_input_weights =
747 operands.exist(recurrent_to_input_weights_index) &&
748 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
749 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
750 bool has_input_gate_bias =
751 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
752 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
753 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
754 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
755 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
756 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
757 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
758 bool has_projection_weights = operands.exist(projection_weights_index) &&
759 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
760 operands.at(projection_weights_index).shape().dim(1) != 0);
761 bool has_projection_bias =
762 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
763
764 // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
765 // true: no CIFG
766 // false: CIFG
767 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
768
769 // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
770 // true: peephole
771 // false: no peephole
772 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
773
774 // NOTE The projection weights may have data but the projection bias may not.
775 bool has_projection_param = has_projection_weights;
776
777 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
778 ? operands.at(input_index).shape().dim(1)
779 : operands.at(input_index).shape().dim(0);
780 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
781 batch_size == operands.at(cell_state_in_index).shape().dim(0));
782
783 const auto input_size =
784 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
785 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
786 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
787 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
788
789 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
790 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
791 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
792 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
793 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
794 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
795 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
796 num_units == operands.at(cell_bias_index).shape().dim(0) &&
797 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
798 num_units == operands.at(cell_state_in_index).shape().dim(1));
799
800 const auto output_size =
801 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
802 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
803 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
804 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
805 output_size == operands.at(output_state_in_index).shape().dim(1));
806
807 if (has_cifg_param)
808 {
809 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
811 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
812 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
813 ((operands.exist(cell_to_input_weights_index) &&
814 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
815 (!operands.exist(cell_to_input_weights_index) ||
816 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
817 num_units == operands.at(input_gate_bias_index).shape().dim(0));
818 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
819 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
820 has_input_gate_bias);
821 if (has_cell_to_input_weights)
822 {
823 // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
824 OP_REQUIRES(has_peephole_param);
825 }
826 if (operands.exist(scratch_buffer_index))
827 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
828 }
829 else
830 {
831 if (operands.exist(scratch_buffer_index))
832 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
833 }
834
835 if (has_peephole_param)
836 {
837 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
838 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
839 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
840 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
841 }
842
843 if (has_projection_param)
844 {
845 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
846 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
847 if (has_projection_bias)
848 {
849 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
850 }
851 }
852
853 if (operands.exist(scratch_buffer_index))
854 {
855 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
856 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
857 }
858
859 if (operands.exist(output_state_out_index))
860 {
861 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
862 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
863 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
864 }
865
866 if (operands.exist(cell_state_out_index))
867 {
868 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
869 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
870 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
871 }
872}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::LSTM::CELL_BIAS, onert::ir::operation::LSTM::CELL_STATE_IN, onert::ir::operation::LSTM::CELL_STATE_OUT, onert::ir::operation::LSTM::CELL_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::FORGET_GATE_BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::LSTM::INPUT, onert::ir::operation::LSTM::INPUT_GATE_BIAS, onert::ir::operation::LSTM::INPUT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_OUTPUT_WEIGHTS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::LSTM::OUTPUT, onert::ir::operation::LSTM::OUTPUT_GATE_BIAS, onert::ir::operation::LSTM::OUTPUT_STATE_IN, onert::ir::operation::LSTM::OUTPUT_STATE_OUT, onert::ir::operation::LSTM::param(), onert::ir::operation::LSTM::PROJECTION_BIAS, onert::ir::operation::LSTM::PROJECTION_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::SCRATCH_BUFFER, and onert::ir::operation::LSTM::Param::time_major.

◆ visit() [22/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::MatrixBandPart & node)
override

Definition at line 1090 of file ShapeValidator.cc.

1091{
1092 const auto &operands = _graph.operands();
1093 const auto output_index{node.getOutputs().at(0)};
1094 const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
1095 const auto num_lower_index{
1097 const auto num_upper_index{
1099
1100 // Check for dimension constraints
1101 if (operands.at(output_index).info().isDynamic())
1102 return;
1103
1104 OP_REQUIRES(operands.at(input_index).shape().rank() >= 2); // input must be at least a 2-D matrix
1105 OP_REQUIRES(operands.at(num_upper_index).shape().rank() == 0); // num_upper must be scalar
1106 OP_REQUIRES(operands.at(num_lower_index).shape().rank() == 0); // num_lower must be scalar
1107}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::MatrixBandPart::INPUT, onert::ir::operation::MatrixBandPart::NUM_LOWER_DIAG, onert::ir::operation::MatrixBandPart::NUM_UPPER_DIAG, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [23/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pack & node)
override

Definition at line 597 of file ShapeValidator.cc.

598{
599 const auto &operands = _graph.operands();
600 const auto axis{node.param().axis};
601 const auto output_index{node.getOutputs().at(0)};
602 if (operands.at(output_index).info().isDynamic())
603 return;
604
605 // shape check
606 const auto &output_shape = operands.at(output_index).shape();
607 const auto output_rank = static_cast<int32_t>(output_shape.rank());
608
609 const auto input1_index{node.getInputs().at(0)};
610 const auto &input_shape = operands.at(input1_index).shape();
611
612 OP_REQUIRES(axis >= -output_rank && axis < output_rank);
613 for (const auto &index : node.getInputs())
614 {
615 OP_REQUIRES(input_shape == operands.at(index).shape());
616 }
617}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Pack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Pack::param().

◆ visit() [24/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pad & node)
override

Definition at line 910 of file ShapeValidator.cc.

911{
912 const auto &operands = _graph.operands();
913 const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
914 OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32);
915
916 const auto output_index{node.getInputs().at(0)};
917 if (operands.at(output_index).info().isDynamic())
918 return;
919
920 const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
921
922 const auto &pad_shape = operands.at(pad_index).shape();
923 const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank());
924
925 OP_REQUIRES(pad_shape.rank() == 2);
926 OP_REQUIRES(pad_shape.dim(0) == input_rank);
927 OP_REQUIRES(pad_shape.dim(1) == 2);
928 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
929}
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
uint32_t rank(void) const
Definition TensorShape.h:35
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings)

References onert::ir::OperandIndexSequence::at(), loco::TensorShape::dim(), onert::ir::Operation::getInputs(), onert::ir::operation::Pad::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::Pad::PAD, and loco::TensorShape::rank().

◆ visit() [25/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Permute & node)
override

Definition at line 257 of file ShapeValidator.cc.

258{
259 const auto &operands = _graph.operands();
260 const auto output_index{node.getOutputs().at(0)};
261 if (operands.at(output_index).info().isDynamic())
262 return;
263
264 const auto input_index{node.getInputs().at(0)};
265
266 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
267}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [26/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pool2D & node)
override

Definition at line 245 of file ShapeValidator.cc.

246{
247 const auto &operands = _graph.operands();
248 const auto ofm_index{node.getOutputs().at(0)};
249 if (operands.at(ofm_index).info().isDynamic())
250 return;
251
252 const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
253
254 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
255}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pool2D::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [27/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Range & node)
override

Definition at line 1073 of file ShapeValidator.cc.

1074{
1075 const auto &operands = _graph.operands();
1076 const auto output_index{node.getOutputs().at(0)};
1077 const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
1078 const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
1079 const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
1080
1081 // Check for dimension constraints
1082 if (operands.at(output_index).info().isDynamic())
1083 return;
1084
1085 OP_REQUIRES(operands.at(start_index).shape().rank() == 0);
1086 OP_REQUIRES(operands.at(limit_index).shape().rank() == 0);
1087 OP_REQUIRES(operands.at(delta_index).shape().rank() == 0);
1088}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Range::DELTA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Range::LIMIT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Range::START.

◆ visit() [28/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reduce & node)
override

Definition at line 269 of file ShapeValidator.cc.

270{
271 const auto &operands = _graph.operands();
272 const auto output_index{node.getOutputs().at(0)};
273 if (operands.at(output_index).info().isDynamic())
274 return;
275
276 const auto &input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
277 const auto &input_shape = operands.at(input_index).shape();
278 const auto &output_shape = operands.at(output_index).shape();
279
280 OP_REQUIRES(input_shape.rank() <= 4);
281 OP_REQUIRES(output_shape.rank() <= input_shape.rank());
282
283 // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
284 // supports cases reducing height and width or reducing depth.
285 // TODO We have to support all cases of dimensions up to 4.
286 // For correct permuting, we have to set output's shape to be equal in dimension position of the
287 // input. But the positions of the same dimensions in the input and output may be set differently.
288 // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
289 // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
290 // extend it in 4 dimensions, it should be {1,1,3,5}.
291 // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
292 // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
293 // next operation is not desired.
294 if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
295 {
296 if (output_shape.rank() == 2)
297 {
298 // Reducing HW
299 OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
300 input_shape.dim(3) == output_shape.dim(1));
301 }
302 else if (output_shape.rank() == 3)
303 {
304 // Reducing C or
305 // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
307 (input_shape.dim(0) == output_shape.dim(0) && input_shape.dim(1) == output_shape.dim(1) &&
308 input_shape.dim(2) == output_shape.dim(2)) ||
309 (input_shape.dim(0) == output_shape.dim(0) &&
310 (input_shape.dim(1) == output_shape.dim(1) || input_shape.dim(2) == output_shape.dim(1)) &&
311 input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
312 }
313 }
314}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reduce::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and output_shape.

◆ visit() [29/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ResizeBilinear & node)
override

Definition at line 978 of file ShapeValidator.cc.

979{
980 const auto &operands = _graph.operands();
981 const auto output_index{node.getOutputs().at(0)};
982 const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
983
984 if (operands.at(output_index).info().isDynamic())
985 {
986 return;
987 }
988 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
989 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
990}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ResizeBilinear::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [30/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reverse & node)
override

Definition at line 992 of file ShapeValidator.cc.

993{
994 const auto &operands = _graph.operands();
995 const auto output_index{node.getOutputs().at(0)};
996 const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
997
998 if (operands.at(output_index).info().isDynamic())
999 return;
1000 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
1001}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reverse::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [31/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RmsNorm & node)
override

Definition at line 1121 of file ShapeValidator.cc.

1122{
1123 const auto &operands = _graph.operands();
1124 const auto ofm_index{node.getOutputs().at(0)};
1125 if (operands.at(ofm_index).info().isDynamic())
1126 return;
1127
1128 const auto ifm_index{node.getInputs().at(ir::operation::RmsNorm::Input::INPUT)};
1129 const auto gamma_index{node.getInputs().at(ir::operation::RmsNorm::Input::GAMMA)};
1130
1131 const auto &ifm_shape = operands.at(ifm_index).shape();
1132 const auto &ofm_shape = operands.at(ofm_index).shape();
1133 const auto &gamma_shape = operands.at(gamma_index).shape();
1134
1135 OP_REQUIRES(ifm_shape.rank() == 3 || ifm_shape.rank() == 4);
1136 OP_REQUIRES(ifm_shape == ofm_shape);
1137 OP_REQUIRES(gamma_shape.rank() == 1);
1138 OP_REQUIRES((gamma_shape.dim(0) == 1) ||
1139 (gamma_shape.dim(0) == ifm_shape.dim(ifm_shape.rank() - 1)));
1140}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RmsNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RmsNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [32/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RNN & node)
override

Definition at line 335 of file ShapeValidator.cc.

336{
337 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
338 // TODO Support dynamic rnn
339 const auto &operands = _graph.operands();
340 const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
341 if (operands.at(output_index).info().isDynamic())
342 return;
343
344 const auto hidden_state_out_index{
346
347 const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
348 const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
349 const auto recurrent_weights_index{
351 const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
352 const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
353
354 const auto batch_size = operands.at(output_index).shape().dim(0);
355 const auto num_units = operands.at(output_index).shape().dim(1);
356
357 OP_REQUIRES(operands.at(output_index).shape().rank() == 2 &&
358 operands.at(hidden_state_out_index).shape().rank() == 2 &&
359 operands.at(input_index).shape().rank() == 2 &&
360 operands.at(weights_index).shape().rank() == 2 &&
361 operands.at(recurrent_weights_index).shape().rank() == 2 &&
362 operands.at(hidden_state_in_index).shape().rank() == 2);
363 OP_REQUIRES(operands.at(bias_index).shape().rank() == 1);
364
365 OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) &&
366 batch_size == operands.at(hidden_state_in_index).shape().dim(0) &&
367 batch_size == operands.at(hidden_state_out_index).shape().dim(0));
368 OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1));
369
370 OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) &&
371 num_units == operands.at(recurrent_weights_index).shape().dim(0) &&
372 num_units == operands.at(bias_index).shape().dim(0));
373 OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) &&
374 num_units == operands.at(recurrent_weights_index).shape().dim(1) &&
375 num_units == operands.at(hidden_state_in_index).shape().dim(1) &&
376 num_units == operands.at(hidden_state_out_index).shape().dim(1));
377}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RNN::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RNN::HIDDEN_STATE_IN, onert::ir::operation::RNN::HIDDEN_STATE_OUT, onert::ir::operation::RNN::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::RNN::OUTPUT, onert::ir::operation::RNN::RECURRENT_WEIGHTS, and onert::ir::operation::RNN::WEIGHTS.

◆ visit() [33/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RoPE & node)
override

Definition at line 1142 of file ShapeValidator.cc.

1143{
1144 const auto &operands = _graph.operands();
1145 const auto ofm_index{node.getOutputs().at(0)};
1146 if (operands.at(ofm_index).info().isDynamic())
1147 return;
1148
1149 const auto ifm_index{node.getInputs().at(ir::operation::RoPE::Input::INPUT)};
1150 const auto sin_table_index{node.getInputs().at(ir::operation::RoPE::Input::SIN_TABLE)};
1151 const auto cos_table_index{node.getInputs().at(ir::operation::RoPE::Input::COS_TABLE)};
1152
1153 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
1154 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
1155 OP_REQUIRES(operands.at(sin_table_index).shape().rank() == 4);
1156 OP_REQUIRES(operands.at(cos_table_index).shape().rank() == 4);
1157}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RoPE::COS_TABLE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RoPE::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::RoPE::SIN_TABLE.

◆ visit() [34/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Select & node)
override

Definition at line 931 of file ShapeValidator.cc.

932{
933 // TODO Shape validation of select
934}

◆ visit() [35/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Shape & node)
override

Definition at line 967 of file ShapeValidator.cc.

968{
969 const auto &operands = _graph.operands();
970 const auto output_index{node.getOutputs().at(0)};
971 if (operands.at(output_index).info().isDynamic())
972 return;
973
974 [[maybe_unused]] const auto input_index{node.getInputs().at(0)};
975 OP_REQUIRES(operands.at(output_index).shape().rank() == 1);
976}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [36/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Softmax & node)
override

Definition at line 216 of file ShapeValidator.cc.

217{
218 const auto &operands = _graph.operands();
219 const auto output_index{node.getOutputs().at(0)};
220 if (operands.at(output_index).info().isDynamic())
221 return;
222
223 const auto input_index{node.getInputs().at(0)};
224
225 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
226}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [37/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToBatchND &node)
override

Definition at line 379 of file ShapeValidator.cc.

380{
381 const auto &operands = _graph.operands();
382 const auto ofm_index{node.getOutputs().at(0)};
383 if (operands.at(ofm_index).info().isDynamic())
384 return;
385
386 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
387 const auto block_size_index{
389 const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
390
391 const auto input_shape = operands.at(ifm_index).shape().asFeature();
392 const auto output_shape = operands.at(ofm_index).shape().asFeature();
393
394 // All requirement as per NNAPI specification.
395 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
396 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
397 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
398 OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2);
399
400 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
401 OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2);
402 OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2);
403
404 OP_REQUIRES(input_shape.C == output_shape.C);
405}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToBatchND::BLOCK_SIZE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToBatchND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToBatchND::PADDINGS.

◆ visit() [38/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToDepth &node)
override

Definition at line 407 of file ShapeValidator.cc.

408{
409 const auto &operands = _graph.operands();
410 const auto ofm_index{node.getOutputs().at(0)};
411 if (operands.at(ofm_index).info().isDynamic())
412 return;
413
414 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
415
416 const auto input_shape = operands.at(ifm_index).shape().asFeature();
417 const auto output_shape = operands.at(ofm_index).shape().asFeature();
418 const auto block_size = node.param().block_size;
419
420 // All assertions as per NNAPI specification.
421 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
422 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
423 OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
424 OP_REQUIRES(input_shape.N == output_shape.N);
425 OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
426}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToDepth::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToDepth::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToDepth::param().

◆ visit() [39/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Split &node)
override

Definition at line 948 of file ShapeValidator.cc.

949{
950 const auto &operands = _graph.operands();
951 const auto output_index{node.getOutputs().at(0)};
952 if (operands.at(output_index).info().isDynamic())
953 return;
954
955 const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
956 const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
957
958 const auto num_splits = node.param().num_splits;
959 const auto input_rank = operands.at(input_index).shape().rank();
960 auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base());
961 axis = axis < 0 ? axis + input_rank : axis;
962
963 OP_REQUIRES(axis >= 0 && axis < input_rank);
964 OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0);
965}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Split::AXIS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Split::INPUT, onert::ir::operation::Split::Param::num_splits, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Split::param().

◆ visit() [40/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SquaredDifference &node)
override

Definition at line 1014 of file ShapeValidator.cc.

1015{
1016 const auto &operands = _graph.operands();
1017 const auto output_index{node.getOutputs().at(0)};
1018 const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
1019 const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
1020
1021 // Check for dimension constraints
1022 if (operands.at(output_index).info().isDynamic())
1023 return;
1024
1025 auto output_shape = operands.at(output_index).shape();
1026 auto lhs_shape = operands.at(lhs_index).shape();
1027 auto rhs_shape = operands.at(rhs_index).shape();
1028 // Check for output rank
1029 OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
1030 auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
1031
1032 for (int idx = 1; idx <= min_rank; idx++)
1033 {
1034 int l_idx = lhs_shape.rank() - idx;
1035 int r_idx = rhs_shape.rank() - idx;
1036 int out_idx = output_shape.rank() - idx;
1037
1038 OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
1039
1040 auto l_dims = lhs_shape.dim(l_idx);
1041 auto r_dims = rhs_shape.dim(r_idx);
1042 auto out_dims = output_shape.dim(out_idx);
1043
1044 OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
1045 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
1046 }
1047 auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
1048 for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
1049 {
1050 int out_idx = output_shape.rank() - idx;
1051 int tmp_idx = tmp_shape.rank() - idx;
1052
1053 OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
1054 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
1055 }
1056}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SquaredDifference::LHS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SquaredDifference::RHS.

◆ visit() [41/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::StridedSlice &node)
override

Definition at line 936 of file ShapeValidator.cc.

937{
938 const auto &operands = _graph.operands();
939 const auto output_index{node.getOutputs().at(0)};
940 const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
941
942 if (operands.at(output_index).info().isDynamic())
943 return;
944
945 OP_REQUIRES(operands.at(input_index).shape().rank() <= 4);
946}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::StridedSlice::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [42/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Tile &node)
override

Definition at line 1057 of file ShapeValidator.cc.

1058{
1059 const auto &operands = _graph.operands();
1060 const auto output_index{node.getOutputs().at(0)};
1061 if (operands.at(output_index).info().isDynamic())
1062 return;
1063
1064 const auto input_index{node.getInputs().at(0)};
1065 const auto multiple_index{node.getInputs().at(1)};
1066
1067 OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1);
1068 OP_REQUIRES(operands.at(multiple_index).shape().dim(0) ==
1069 operands.at(input_index).shape().rank());
1070 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
1071}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [43/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Transpose &node)
override

Definition at line 316 of file ShapeValidator.cc.

317{
318 const auto &operands = _graph.operands();
319 const auto output_index{node.getOutputs().at(0)};
320 if (operands.at(output_index).info().isDynamic())
321 return;
322
323 const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
324 const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
325
326 const auto &output_shape = operands.at(output_index).shape();
327 const auto &input_shape = operands.at(input_index).shape();
328
329 OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 ||
330 input_shape.rank() ==
331 static_cast<int>(operands.at(perm_index).shape().num_elements()));
332 OP_REQUIRES(input_shape.rank() == output_shape.rank());
333}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Transpose::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Transpose::PERMUTATION.

◆ visit() [44/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::TransposeConv &node)
override

Definition at line 520 of file ShapeValidator.cc.

521{
522 // shape check
523 const auto &operands = _graph.operands();
524 const auto ofm_index{node.getOutputs().at(0)};
525
526 if (operands.at(ofm_index).info().isDynamic())
527 return;
528
529 const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
530 const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
531
532 // Only 4D tensors are supported
533 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
534 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank());
535 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank());
536
537 const auto ofm_shape = operands.at(ofm_index).shape().asFeature();
538 const auto ifm_shape = operands.at(ifm_index).shape().asFeature();
539 // The kernel has only IHWO layout on frontend
540 // So ker_shape is treated here below
541 // I -> N
542 // H -> H
543 // W -> W
544 // O -> C
545 const auto ker_shape = operands.at(ker_index).shape().asFeature();
546
547 OP_REQUIRES(ifm_shape.N == ofm_shape.N);
548 OP_REQUIRES(ifm_shape.C == ker_shape.C);
549 OP_REQUIRES(ker_shape.N == ofm_shape.C);
550}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::TransposeConv::INPUT, onert::ir::operation::TransposeConv::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [45/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Unpack &node)
override

Definition at line 894 of file ShapeValidator.cc.

895{
896 const auto &operands = _graph.operands();
897 const auto axis{node.param().axis};
898 const auto output_index{node.getInputs().at(0)};
899 if (operands.at(output_index).info().isDynamic())
900 return;
901
902 const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
903
904 const auto &input_shape = operands.at(input_index).shape();
905 const auto input_rank = static_cast<int32_t>(input_shape.rank());
906
907 OP_REQUIRES(axis >= -input_rank && axis < input_rank);
908}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Unpack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::operation::Unpack::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Unpack::param().

◆ visit() [46/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::While &node)
override

Definition at line 1008 of file ShapeValidator.cc.

1009{
1010 // This validator does not check shape. So checking isDynamic() is skipped.
1011 // TODO Add to validate with subgraphs
1012}

The documentation for this class was generated from the following files: