ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::compiler::ShapeValidator Class Reference

#include <ShapeValidator.h>

Collaboration diagram for onert::compiler::ShapeValidator:

Public Member Functions

 ShapeValidator (void)=delete
 
 ShapeValidator (const ir::Graph &graph)
 
 ShapeValidator (const ShapeValidator &)=delete
 
 ShapeValidator (ShapeValidator &&)=delete
 
 ~ShapeValidator ()=default
 
ShapeValidator & operator= (const ShapeValidator &)=delete
 
ShapeValidator & operator= (ShapeValidator &&)=delete
 
void operator() ()
 
void visit (const ir::operation::BatchMatMul &node) override
 
void visit (const ir::operation::BatchToSpaceND &node) override
 
void visit (const ir::operation::BCQFullyConnected &node) override
 
void visit (const ir::operation::BCQGather &node) override
 
void visit (const ir::operation::Conv2D &node) override
 
void visit (const ir::operation::Comparison &node) override
 
void visit (const ir::operation::DepthwiseConv2D &node) override
 
void visit (const ir::operation::FullyConnected &node) override
 
void visit (const ir::operation::Softmax &node) override
 
void visit (const ir::operation::InstanceNorm &node) override
 
void visit (const ir::operation::Permute &node) override
 
void visit (const ir::operation::Pool2D &node) override
 
void visit (const ir::operation::Reduce &node) override
 
void visit (const ir::operation::Transpose &node) override
 
void visit (const ir::operation::RNN &node) override
 
void visit (const ir::operation::SpaceToBatchND &node) override
 
void visit (const ir::operation::SpaceToDepth &node) override
 
void visit (const ir::operation::ElementwiseActivation &node) override
 
void visit (const ir::operation::ElementwiseBinary &node) override
 
void visit (const ir::operation::ElementwiseUnary &node) override
 
void visit (const ir::operation::EmbeddingLookup &node) override
 
void visit (const ir::operation::ExpandDims &node) override
 
void visit (const ir::operation::HashtableLookup &node) override
 
void visit (const ir::operation::TransposeConv &node) override
 
void visit (const ir::operation::Gather &node) override
 
void visit (const ir::operation::DepthToSpace &node) override
 
void visit (const ir::operation::Pack &node) override
 
void visit (const ir::operation::LSTM &node) override
 
void visit (const ir::operation::L2Normalization &node) override
 
void visit (const ir::operation::Unpack &node) override
 
void visit (const ir::operation::Pad &node) override
 
void visit (const ir::operation::Select &node) override
 
void visit (const ir::operation::StridedSlice &node) override
 
void visit (const ir::operation::Split &node) override
 
void visit (const ir::operation::Shape &node) override
 
void visit (const ir::operation::ResizeBilinear &node) override
 
void visit (const ir::operation::Reverse &node) override
 
void visit (const ir::operation::If &node) override
 
void visit (const ir::operation::While &node) override
 
void visit (const ir::operation::SquaredDifference &node) override
 
void visit (const ir::operation::Tile &node) override
 
void visit (const ir::operation::Range &node) override
 
void visit (const ir::operation::MatrixBandPart &node) override
 
void visit (const ir::operation::LogSoftmax &node) override
 
void visit (const ir::operation::RmsNorm &node) override
 
void visit (const ir::operation::RoPE &node) override
 
- Public Member Functions inherited from onert::ir::OperationVisitor
virtual ~OperationVisitor ()=default
 

Detailed Description

Definition at line 37 of file ShapeValidator.h.

Constructor & Destructor Documentation

◆ ShapeValidator() [1/4]

onert::compiler::ShapeValidator::ShapeValidator ( void  )
delete

◆ ShapeValidator() [2/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ir::Graph & graph)

Definition at line 36 of file ShapeValidator.cc.

36: _graph{graph} {}

◆ ShapeValidator() [3/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ShapeValidator & )
delete

◆ ShapeValidator() [4/4]

onert::compiler::ShapeValidator::ShapeValidator ( ShapeValidator &&  )
delete

◆ ~ShapeValidator()

onert::compiler::ShapeValidator::~ShapeValidator ( )
default

Member Function Documentation

◆ operator()()

void onert::compiler::ShapeValidator::operator() ( )

Definition at line 51 of file ShapeValidator.cc.

52{
53 _graph.operations().iterate(
54 [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); });
55}
const Operations & operations() const override
Definition Graph.h:114
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
::onert::util::Index< uint32_t, OperationIndexTag > OperationIndex
Definition Index.h:32

References onert::ir::IOperation::accept(), onert::util::ObjectManager< Index, Object >::iterate(), and onert::ir::Graph::operations().

◆ operator=() [1/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( const ShapeValidator & )
delete

◆ operator=() [2/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( ShapeValidator &&  )
delete

◆ visit() [1/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchMatMul & node)
override

Definition at line 57 of file ShapeValidator.cc.

58{
59 const auto &operands = _graph.operands();
60 const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
61 const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
62 const auto out_index{node.getOutputs().at(0)};
63
64 if (operands.at(out_index).info().isDynamic())
65 return;
66
67 OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4);
68 OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4);
69 OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2);
70 OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2);
71}
#define OP_REQUIRES(EXP)
const Operands & operands() const override
Definition Graph.h:112

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchMatMul::LHS, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::BatchMatMul::RHS.

◆ visit() [2/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchToSpaceND & node)
override

Definition at line 73 of file ShapeValidator.cc.

74{
75 const auto &operands = _graph.operands();
76 const auto ofm_index{node.getOutputs().at(0)};
77 if (operands.at(ofm_index).info().isDynamic())
78 return;
79
80 const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
81 const auto block_size_index{
82 node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
83
84 const auto input_shape = operands.at(ifm_index).shape().asFeature();
85 const auto output_shape = operands.at(ofm_index).shape().asFeature();
86
87 // All requirement as per NNAPI specification.
88 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
89 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
90 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
91
92 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
93
94 if (node.getInputs().size() != 2)
95 {
96 const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
97 OP_REQUIRES(operands.at(crops_index).shape().rank() == 2);
98 OP_REQUIRES(operands.at(crops_index).shape().dim(0) ==
99 (operands.at(ifm_index).shape().rank() - 2));
100 OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2);
101 }
102
103 OP_REQUIRES(input_shape.C == output_shape.C);
104}
const Object & at(const Index &index) const
Get the object that is associated with the given index.
const luci_interpreter::RuntimeShape output_shape

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BatchToSpaceND::BLOCK_SIZE, onert::ir::operation::BatchToSpaceND::CROPS_DATA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchToSpaceND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::OperandIndexSequence::size().

◆ visit() [3/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQFullyConnected & node)
override

Definition at line 106 of file ShapeValidator.cc.

107{
108 const auto &operands = _graph.operands();
109 const auto ofm_index{node.getOutputs().at(0)};
110 if (operands.at(ofm_index).info().isDynamic())
111 return;
112
113 const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
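114 const auto weight_scales_index{
115 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
116 const auto weight_binary_index{
117 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
118 const auto weight_cluster_index{
119 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};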
120 const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
121
122 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
123 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
124 OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
125 OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
126 OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
127
128 OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
129
130 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
131 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
132
133 // more shape validation will be done inside kernel.
134
135 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
136}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BCQFullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQFullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::BCQFullyConnected::WEIGHTS_BINARY, onert::ir::operation::BCQFullyConnected::WEIGHTS_CLUSTERS, and onert::ir::operation::BCQFullyConnected::WEIGHTS_SCALES.

◆ visit() [4/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQGather & node)
override

Definition at line 138 of file ShapeValidator.cc.

139{
140 const auto &operands = _graph.operands();
141 const auto ofm_index{node.getOutputs().at(0)};
142 if (operands.at(ofm_index).info().isDynamic())
143 return;
144
145 const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
146 const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
147 const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
148 const auto input_clusters_index{
149 node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
150
151 OP_REQUIRES(operands.at(indices_index).shape().rank() <=
152 2); // TODO : support rank up to 4 or more
153 OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2);
154 OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1);
155 OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2);
156
157 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0);
158 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2);
159
160 // more shape validation will be done inside kernel.
161}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQGather::INDICES, onert::ir::operation::BCQGather::INPUT_BINARY, onert::ir::operation::BCQGather::INPUT_CLUSTERS, onert::ir::operation::BCQGather::INPUT_SCALES, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [5/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Comparison & node)
override

Definition at line 180 of file ShapeValidator.cc.

181{
182 // TODO Shape validation of comparison
183}

◆ visit() [6/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Conv2D & node)
override

Definition at line 163 of file ShapeValidator.cc.

164{
165 const auto &operands = _graph.operands();
166 const auto ofm_index{node.getOutputs().at(0)};
167 if (operands.at(ofm_index).info().isDynamic())
168 return;
169
170 const auto ifm_index{node.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
171 const auto ker_index{node.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
172 const auto bias_index{node.getInputs().at(ir::operation::Conv2D::Input::BIAS)};
173
174 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
175 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
176 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
177 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
178}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Conv2D::INPUT, onert::ir::operation::Conv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [7/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthToSpace & node)
override

Definition at line 573 of file ShapeValidator.cc.

574{
575 const auto &operands = _graph.operands();
576 int32_t block_size = node.param().block_size;
577
578 // shape check
579 const auto output_index{node.getOutputs().at(0)};
580 if (operands.at(output_index).info().isDynamic())
581 return;
582
583 const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
584
585 const auto output_shape = operands.at(output_index).shape().asFeature();
586 const auto input_shape = operands.at(input_index).shape().asFeature();
587
588 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
589 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
590
591 {
592 OP_REQUIRES(output_shape.N == input_shape.N);
593 OP_REQUIRES(output_shape.H == input_shape.H * block_size);
594 OP_REQUIRES(output_shape.W == input_shape.W * block_size);
595 OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
596 OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
597 }
598}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthToSpace::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthToSpace::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::DepthToSpace::param().

◆ visit() [8/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthwiseConv2D & node)
override

Definition at line 185 of file ShapeValidator.cc.

186{
187 const auto &operands = _graph.operands();
188 const auto ofm_index{node.getOutputs().at(0)};
189 if (operands.at(ofm_index).info().isDynamic())
190 return;
191
192 const auto ifm_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)};
193 const auto ker_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)};
194 const auto bias_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::BIAS)};
195
196 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
197 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
198 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
199 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
200}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthwiseConv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthwiseConv2D::INPUT, onert::ir::operation::DepthwiseConv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [9/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseActivation & node)
override

Definition at line 430 of file ShapeValidator.cc.

430{ checkUnaryOp(node); }

◆ visit() [10/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseBinary & node)
override

Definition at line 432 of file ShapeValidator.cc.

433{
434 // TODO Shape validation of ElementwiseBinary
435}

◆ visit() [11/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseUnary & node)
override

Definition at line 437 of file ShapeValidator.cc.

438{
439 const auto &operands = _graph.operands();
440 const auto output_index{node.getOutputs().at(0)};
441 const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
442
443 if (operands.at(output_index).info().isDynamic())
444 return;
445
446 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
447}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseUnary::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [12/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::EmbeddingLookup & node)
override

Definition at line 449 of file ShapeValidator.cc.

450{
451 const auto &operands = _graph.operands();
452 const auto output_index{node.getOutputs().at(0)};
453 const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
454 const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
455
456 const auto &output_obj = operands.at(output_index);
457 const auto &lookups_obj = operands.at(lookups_index);
458 const auto &values_obj = operands.at(values_index);
459
460 // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
461 // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
462 {
463 if (operands.at(output_index).info().isDynamic())
464 return;
465
466 const auto &output_shape = output_obj.shape();
467 const auto &lookups_shape = lookups_obj.shape();
468 const auto &values_shape = values_obj.shape();
469
470 OP_REQUIRES(lookups_shape.rank() == 1);
471 OP_REQUIRES(values_shape.rank() >= 2);
472
473 // output should be a n-D tensor with the same rank and shape as the values tensor, except for
474 // the first dimension which has the same size as lookups' only dimension.
475 OP_REQUIRES(output_shape.rank() == values_shape.rank());
476 OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
477 for (int n = 1; n < output_shape.rank(); ++n)
478 {
479 OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
480 }
481 }
482}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::EmbeddingLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::EmbeddingLookup::VALUES.

◆ visit() [13/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ExpandDims & node)
override

Definition at line 484 of file ShapeValidator.cc.

485{
486 const auto &operands = _graph.operands();
487 const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
488
489 if (operands.at(axis_index).info().isDynamic())
490 return;
491 OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1);
492}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ExpandDims::AXIS, onert::ir::Operation::getInputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [14/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::FullyConnected & node)
override

Definition at line 202 of file ShapeValidator.cc.

203{
204 const auto &operands = _graph.operands();
205 const auto ofm_index{node.getOutputs().at(0)};
206 if (operands.at(ofm_index).info().isDynamic())
207 return;
208
209 const auto ifm_index{node.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
210 const auto ker_index{node.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
211 const auto bias_index{node.getInputs().at(ir::operation::FullyConnected::Input::BIAS)};
212
213 OP_REQUIRES(operands.at(ifm_index).shape().rank() >= 2);
214 OP_REQUIRES(operands.at(ker_index).shape().rank() == 2);
215 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
216}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::FullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::FullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::FullyConnected::WEIGHT.

◆ visit() [15/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Gather & node)
override

Definition at line 554 of file ShapeValidator.cc.

555{
556 const auto &operands = _graph.operands();
557 const auto ofm_index{node.getOutputs().at(0)};
558 if (operands.at(ofm_index).info().isDynamic())
559 return;
560
561 const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
562 const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
563
564 const auto &ifm_shape = operands.at(ifm_index).shape();
565 const auto &indices_shape = operands.at(indices_index).shape();
566 const auto &ofm_shape = operands.at(ofm_index).shape();
567
568 OP_REQUIRES(ifm_shape.rank() <= 4);
569 OP_REQUIRES(indices_shape.rank() <= 3);
570 OP_REQUIRES(ofm_shape.rank() <= 4);
571}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Gather::INDICES, onert::ir::operation::Gather::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [16/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::HashtableLookup & node)
override

Definition at line 494 of file ShapeValidator.cc.

495{
496 const auto &operands = _graph.operands();
497 const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
498 const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
499 const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
500 const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
501
502 const auto &output_obj = operands.at(output_index);
503 const auto &lookups_obj = operands.at(lookups_index);
504 const auto &keys_obj = operands.at(keys_index);
505 const auto &values_obj = operands.at(values_index);
506
507 if (operands.at(output_index).info().isDynamic())
508 return;
509
510 const auto &output_shape = output_obj.shape();
511 const auto &lookups_shape = lookups_obj.shape();
512 const auto &keys_shape = keys_obj.shape();
513 const auto &values_shape = values_obj.shape();
514
515 OP_REQUIRES(values_shape.rank() == output_shape.rank());
516 OP_REQUIRES(lookups_shape.rank() == 1);
517 OP_REQUIRES(keys_shape.rank() == 1);
518 OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
519 OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
520}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::HashtableLookup::KEYS, onert::ir::operation::HashtableLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::HashtableLookup::OUTPUT, output_shape, and onert::ir::operation::HashtableLookup::VALUES.

◆ visit() [17/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::If & node)
override

Definition at line 1006 of file ShapeValidator.cc.

1007{
1008 // TODO Add to validate with subgraphs
1009}

◆ visit() [18/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::InstanceNorm & node)
override

Definition at line 230 of file ShapeValidator.cc.

231{
232 const auto &operands = _graph.operands();
233 const auto ofm_index{node.getOutputs().at(0)};
234 if (operands.at(ofm_index).info().isDynamic())
235 return;
236
237 const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
238 const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
239 const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
240
241 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
242 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
243 OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1);
244 OP_REQUIRES(operands.at(beta_index).shape().rank() == 1);
245}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::InstanceNorm::BETA, onert::ir::operation::InstanceNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::InstanceNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [19/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::L2Normalization & node)
override

Definition at line 877 of file ShapeValidator.cc.

878{
879 const auto &operands = _graph.operands();
880 const auto ofm_index{node.getOutputs().at(0)};
881 if (operands.at(ofm_index).info().isDynamic())
882 return;
883
884 const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
885
886 auto ifm_shape = operands.at(ifm_index).shape();
887 auto ofm_shape = operands.at(ofm_index).shape();
888
889 OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
890
891 for (auto i = 0; i < ifm_shape.rank(); i++)
892 {
893 OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
894 }
895}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::L2Normalization::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [20/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LogSoftmax & node)
override

Definition at line 1112 of file ShapeValidator.cc.

1113{
1114 const auto &operands = _graph.operands();
1115 const auto output_index{node.getOutputs().at(0)};
1116 if (operands.at(output_index).info().isDynamic())
1117 return;
1118
1119 const auto input_index{node.getInputs().at(0)};
1120
1121 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
1122}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [21/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LSTM & node)
override

Definition at line 622 of file ShapeValidator.cc.

623{
624 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
625 // TODO Support dynamic rnn
626 const auto &operands = _graph.operands();
627 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
628 if (operands.at(output_index).info().isDynamic())
629 return;
630
631 const auto scratch_buffer_index{
632 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
633 const auto output_state_out_index{
634 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
635 const auto cell_state_out_index{
636 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
637
638 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
639 const auto input_to_input_weights_index{
640 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
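641 const auto input_to_forget_weights_index{
642 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
643 const auto input_to_cell_weights_index{
644 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
645 const auto input_to_output_weights_index{
646 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
647 const auto recurrent_to_input_weights_index{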
648 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
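649 const auto recurrent_to_forget_weights_index{
650 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
651 const auto recurrent_to_cell_weights_index{
652 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
653 const auto recurrent_to_output_weights_index{
654 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
655 const auto cell_to_input_weights_index{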
656 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
657 const auto cell_to_forget_weights_index{
658 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
659 const auto cell_to_output_weights_index{
660 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
661 const auto input_gate_bias_index{
662 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
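663 const auto forget_gate_bias_index{
664 node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
665 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
666 const auto output_gate_bias_index{
667 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
668 const auto projection_weights_index{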
669 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
670 const auto projection_bias_index{
671 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
672 const auto output_state_in_index{
673 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
674 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
675
676 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
677 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
678 {
679 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
680 operands.at(output_index).shape().dim(i));
681 }
682 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
683 operands.at(output_index).shape().rank() == 3) &&
684 (operands.at(input_index).shape().rank() == 2 ||
685 operands.at(input_index).shape().rank() == 3) &&
686 (!operands.exist(input_to_input_weights_index) ||
687 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
688 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
689 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
690 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
691 (!operands.exist(recurrent_to_input_weights_index) ||
692 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
693 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
694 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
695 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
696 (!operands.exist(projection_weights_index) ||
697 operands.at(projection_weights_index).shape().rank() == 2) &&
698 operands.at(output_state_in_index).shape().rank() == 2 &&
699 operands.at(cell_state_in_index).shape().rank() == 2);
700
701 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
702 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
703 (!operands.exist(cell_to_forget_weights_index) ||
704 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
705 (!operands.exist(cell_to_output_weights_index) ||
706 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
707 (!operands.exist(input_gate_bias_index) ||
708 operands.at(input_gate_bias_index).shape().rank() == 1) &&
709 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
710 operands.at(cell_bias_index).shape().rank() == 1 &&
711 operands.at(output_gate_bias_index).shape().rank() == 1 &&
712 (!operands.exist(projection_bias_index) ||
713 operands.at(projection_bias_index).shape().rank() == 1));
714
715 // CIFG assertion
716 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
717 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
718 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
719 (!operands.exist(recurrent_to_input_weights_index) ||
720 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
721 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
722 (!operands.exist(input_gate_bias_index) ||
723 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
724 (!operands.exist(cell_to_input_weights_index) ||
725 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
726 ((operands.exist(input_to_input_weights_index) &&
727 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
728 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
729 (operands.exist(recurrent_to_input_weights_index) &&
730 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
731 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
732 (operands.exist(input_gate_bias_index) &&
733 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
734
735 // Peephole assertion
736 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
737 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
738 (!operands.exist(cell_to_output_weights_index) ||
739 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
740 ((operands.exist(cell_to_forget_weights_index) &&
741 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
742 (operands.exist(cell_to_output_weights_index) &&
743 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
744
745 bool has_input_to_input_weights =
746 operands.exist(input_to_input_weights_index) &&
747 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
748 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
749 bool has_recurrent_to_input_weights =
750 operands.exist(recurrent_to_input_weights_index) &&
751 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
752 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
753 bool has_input_gate_bias =
754 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
755 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
756 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
757 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
758 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
759 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
760 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
761 bool has_projection_weights = operands.exist(projection_weights_index) &&
762 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
763 operands.at(projection_weights_index).shape().dim(1) != 0);
764 bool has_projection_bias =
765 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
766
767 // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
768 // true: no CIFG
769 // false: CIFG
770 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
771
772 // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
773 // true: peephole
774 // false: no peephole
775 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
776
777 // NOTE The projection weights may have data but the projection bias may not.
778 bool has_projection_param = has_projection_weights;
779
780 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
781 ? operands.at(input_index).shape().dim(1)
782 : operands.at(input_index).shape().dim(0);
783 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
784 batch_size == operands.at(cell_state_in_index).shape().dim(0));
785
786 const auto input_size =
787 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
788 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
789 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
790 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
791
792 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
793 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
794 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
795 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
796 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
797 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
798 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
799 num_units == operands.at(cell_bias_index).shape().dim(0) &&
800 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
801 num_units == operands.at(cell_state_in_index).shape().dim(1));
802
803 const auto output_size =
804 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
805 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
806 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
807 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
808 output_size == operands.at(output_state_in_index).shape().dim(1));
809
810 if (has_cifg_param)
811 {
812 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
813 OP_REQUIRES(
814 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
815 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
816 ((operands.exist(cell_to_input_weights_index) &&
817 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
818 (!operands.exist(cell_to_input_weights_index) ||
819 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
820 num_units == operands.at(input_gate_bias_index).shape().dim(0));
821 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
822 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
823 has_input_gate_bias);
824 if (has_cell_to_input_weights)
825 {
826 // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
827 OP_REQUIRES(has_peephole_param);
828 }
829 if (operands.exist(scratch_buffer_index))
830 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
831 }
832 else
833 {
834 if (operands.exist(scratch_buffer_index))
835 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
836 }
837
838 if (has_peephole_param)
839 {
840 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
841 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
842 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
843 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
844 }
845
846 if (has_projection_param)
847 {
848 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
849 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
850 if (has_projection_bias)
851 {
852 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
853 }
854 }
855
856 if (operands.exist(scratch_buffer_index))
857 {
858 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
859 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
860 }
861
862 if (operands.exist(output_state_out_index))
863 {
864 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
865 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
866 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
867 }
868
869 if (operands.exist(cell_state_out_index))
870 {
871 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
872 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
873 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
874 }
875}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::LSTM::CELL_BIAS, onert::ir::operation::LSTM::CELL_STATE_IN, onert::ir::operation::LSTM::CELL_STATE_OUT, onert::ir::operation::LSTM::CELL_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::FORGET_GATE_BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::LSTM::INPUT, onert::ir::operation::LSTM::INPUT_GATE_BIAS, onert::ir::operation::LSTM::INPUT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_OUTPUT_WEIGHTS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::LSTM::OUTPUT, onert::ir::operation::LSTM::OUTPUT_GATE_BIAS, onert::ir::operation::LSTM::OUTPUT_STATE_IN, onert::ir::operation::LSTM::OUTPUT_STATE_OUT, onert::ir::operation::LSTM::param(), onert::ir::operation::LSTM::PROJECTION_BIAS, onert::ir::operation::LSTM::PROJECTION_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::SCRATCH_BUFFER, and onert::ir::operation::LSTM::Param::time_major.

◆ visit() [22/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::MatrixBandPart & node)
override

Definition at line 1093 of file ShapeValidator.cc.

1094{
1095 const auto &operands = _graph.operands();
1096 const auto output_index{node.getOutputs().at(0)};
1097 const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
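1098 const auto num_lower_index{
1099 node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
1100 const auto num_upper_index{
1101 node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
1102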
1103 // Check for dimension constraints
1104 if (operands.at(output_index).info().isDynamic())
1105 return;
1106
1107 OP_REQUIRES(operands.at(input_index).shape().rank() >= 2); // input must be more than 2 dim matrix
1108 OP_REQUIRES(operands.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar
1109 OP_REQUIRES(operands.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar
1110}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::MatrixBandPart::INPUT, onert::ir::operation::MatrixBandPart::NUM_LOWER_DIAG, onert::ir::operation::MatrixBandPart::NUM_UPPER_DIAG, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [23/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pack & node)
override

Definition at line 600 of file ShapeValidator.cc.

601{
602 const auto &operands = _graph.operands();
603 const auto axis{node.param().axis};
604 const auto output_index{node.getOutputs().at(0)};
605 if (operands.at(output_index).info().isDynamic())
606 return;
607
608 // shape check
609 const auto &output_shape = operands.at(output_index).shape();
610 const auto output_rank = static_cast<int32_t>(output_shape.rank());
611
612 const auto input1_index{node.getInputs().at(0)};
613 const auto &input_shape = operands.at(input1_index).shape();
614
615 OP_REQUIRES(axis >= -output_rank && axis < output_rank);
616 for (const auto &index : node.getInputs())
617 {
618 OP_REQUIRES(input_shape == operands.at(index).shape());
619 }
620}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Pack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Pack::param().

◆ visit() [24/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pad & node)
override

Definition at line 913 of file ShapeValidator.cc.

914{
915 const auto &operands = _graph.operands();
916 const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
917 OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32);
918
919 const auto output_index{node.getInputs().at(0)};
920 if (operands.at(output_index).info().isDynamic())
921 return;
922
923 const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
924
925 const auto &pad_shape = operands.at(pad_index).shape();
926 const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank());
927
928 OP_REQUIRES(pad_shape.rank() == 2);
929 OP_REQUIRES(pad_shape.dim(0) == input_rank);
930 OP_REQUIRES(pad_shape.dim(1) == 2);
931 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
932}
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
uint32_t rank(void) const
Definition TensorShape.h:35
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings)

References onert::ir::OperandIndexSequence::at(), loco::TensorShape::dim(), onert::ir::Operation::getInputs(), onert::ir::operation::Pad::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::Pad::PAD, and loco::TensorShape::rank().

◆ visit() [25/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Permute & node)
override

Definition at line 259 of file ShapeValidator.cc.

260{
261 const auto &operands = _graph.operands();
262 const auto output_index{node.getOutputs().at(0)};
263 if (operands.at(output_index).info().isDynamic())
264 return;
265
266 const auto input_index{node.getInputs().at(0)};
267
268 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
269}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [26/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pool2D & node)
override

Definition at line 247 of file ShapeValidator.cc.

248{
249 const auto &operands = _graph.operands();
250 const auto ofm_index{node.getOutputs().at(0)};
251 if (operands.at(ofm_index).info().isDynamic())
252 return;
253
254 const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
255
256 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
257}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pool2D::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [27/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Range & node)
override

Definition at line 1076 of file ShapeValidator.cc.

1077{
1078 const auto &operands = _graph.operands();
1079 const auto output_index{node.getOutputs().at(0)};
1080 const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
1081 const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
1082 const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
1083
1084 // Check for dimension constraints
1085 if (operands.at(output_index).info().isDynamic())
1086 return;
1087
1088 OP_REQUIRES(operands.at(start_index).shape().rank() == 0);
1089 OP_REQUIRES(operands.at(limit_index).shape().rank() == 0);
1090 OP_REQUIRES(operands.at(delta_index).shape().rank() == 0);
1091}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Range::DELTA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Range::LIMIT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Range::START.

◆ visit() [28/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reduce & node)
override

Definition at line 271 of file ShapeValidator.cc.

272{
273 const auto &operands = _graph.operands();
274 const auto output_index{node.getOutputs().at(0)};
275 if (operands.at(output_index).info().isDynamic())
276 return;
277
278 const auto &input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
279 const auto &input_shape = operands.at(input_index).shape();
280 const auto &output_shape = operands.at(output_index).shape();
281
282 OP_REQUIRES(input_shape.rank() <= 4);
283 OP_REQUIRES(output_shape.rank() <= input_shape.rank());
284
285 // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
286 // supports cases reducing height and width or reducing depth.
287 // TODO We have to support all cases of dimensions up to 4.
288 // For correct permuting, we have to set output's shape to be equal in dimension position of the
289 // input. But the positions of the same dimensions in the input and output may be set differently.
290 // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
291 // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
292 // extend it in 4 dimensions, it should be {1,1,3,5}.
293 // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
294 // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
295 // next operation is not desired.
296 if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
297 {
298 if (output_shape.rank() == 2)
299 {
300 // Reducing HW
301 OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
302 input_shape.dim(3) == output_shape.dim(1));
303 }
304 else if (output_shape.rank() == 3)
305 {
306 // Reducing C or
307 // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
308 OP_REQUIRES(
309 (input_shape.dim(0) == output_shape.dim(0) && input_shape.dim(1) == output_shape.dim(1) &&
310 input_shape.dim(2) == output_shape.dim(2)) ||
311 (input_shape.dim(0) == output_shape.dim(0) &&
312 (input_shape.dim(1) == output_shape.dim(1) || input_shape.dim(2) == output_shape.dim(1)) &&
313 input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
314 }
315 }
316}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reduce::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and output_shape.

◆ visit() [29/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ResizeBilinear & node)
override

Definition at line 981 of file ShapeValidator.cc.

982{
983 const auto &operands = _graph.operands();
984 const auto output_index{node.getOutputs().at(0)};
985 const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
986
987 if (operands.at(output_index).info().isDynamic())
988 {
989 return;
990 }
991 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
992 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
993}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ResizeBilinear::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [30/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reverse & node)
override

Definition at line 995 of file ShapeValidator.cc.

996{
997 const auto &operands = _graph.operands();
998 const auto output_index{node.getOutputs().at(0)};
999 const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
1000
1001 if (operands.at(output_index).info().isDynamic())
1002 return;
1003 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
1004}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reverse::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [31/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RmsNorm & node)
override

Definition at line 1124 of file ShapeValidator.cc.

1125{
1126 const auto &operands = _graph.operands();
1127 const auto ofm_index{node.getOutputs().at(0)};
1128 if (operands.at(ofm_index).info().isDynamic())
1129 return;
1130
1131 const auto ifm_index{node.getInputs().at(ir::operation::RmsNorm::Input::INPUT)};
1132 const auto gamma_index{node.getInputs().at(ir::operation::RmsNorm::Input::GAMMA)};
1133
1134 const auto &ifm_shape = operands.at(ifm_index).shape();
1135 const auto &ofm_shape = operands.at(ofm_index).shape();
1136 const auto &gamma_shape = operands.at(gamma_index).shape();
1137
1138 OP_REQUIRES(ifm_shape.rank() == 3 || ifm_shape.rank() == 4);
1139 OP_REQUIRES(ifm_shape == ofm_shape);
1140 OP_REQUIRES(gamma_shape.rank() == 1);
1141 OP_REQUIRES((gamma_shape.dim(0) == 1) ||
1142 (gamma_shape.dim(0) == ifm_shape.dim(ifm_shape.rank() - 1)));
1143}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RmsNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RmsNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [32/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RNN & node)
override

Definition at line 337 of file ShapeValidator.cc.

338{
339 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
340 // TODO Support dynamic rnn
341 const auto &operands = _graph.operands();
342 const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
343 if (operands.at(output_index).info().isDynamic())
344 return;
345
346 const auto hidden_state_out_index{
347 node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
348
349 const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
350 const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
351 const auto recurrent_weights_index{
352 node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
353 const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
354 const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
355
356 const auto batch_size = operands.at(output_index).shape().dim(0);
357 const auto num_units = operands.at(output_index).shape().dim(1);
358
359 OP_REQUIRES(operands.at(output_index).shape().rank() == 2 &&
360 operands.at(hidden_state_out_index).shape().rank() == 2 &&
361 operands.at(input_index).shape().rank() == 2 &&
362 operands.at(weights_index).shape().rank() == 2 &&
363 operands.at(recurrent_weights_index).shape().rank() == 2 &&
364 operands.at(hidden_state_in_index).shape().rank() == 2);
365 OP_REQUIRES(operands.at(bias_index).shape().rank() == 1);
366
367 OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) &&
368 batch_size == operands.at(hidden_state_in_index).shape().dim(0) &&
369 batch_size == operands.at(hidden_state_out_index).shape().dim(0));
370 OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1));
371
372 OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) &&
373 num_units == operands.at(recurrent_weights_index).shape().dim(0) &&
374 num_units == operands.at(bias_index).shape().dim(0));
375 OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) &&
376 num_units == operands.at(recurrent_weights_index).shape().dim(1) &&
377 num_units == operands.at(hidden_state_in_index).shape().dim(1) &&
378 num_units == operands.at(hidden_state_out_index).shape().dim(1));
379}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RNN::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RNN::HIDDEN_STATE_IN, onert::ir::operation::RNN::HIDDEN_STATE_OUT, onert::ir::operation::RNN::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::RNN::OUTPUT, onert::ir::operation::RNN::RECURRENT_WEIGHTS, and onert::ir::operation::RNN::WEIGHTS.

◆ visit() [33/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RoPE & node)
override

Definition at line 1145 of file ShapeValidator.cc.

1146{
1147 const auto &operands = _graph.operands();
1148 const auto ofm_index{node.getOutputs().at(0)};
1149 if (operands.at(ofm_index).info().isDynamic())
1150 return;
1151
1152 const auto ifm_index{node.getInputs().at(ir::operation::RoPE::Input::INPUT)};
1153 const auto sin_table_index{node.getInputs().at(ir::operation::RoPE::Input::SIN_TABLE)};
1154 const auto cos_table_index{node.getInputs().at(ir::operation::RoPE::Input::COS_TABLE)};
1155
1156 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
1157 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
1158 OP_REQUIRES(operands.at(sin_table_index).shape().rank() == 4);
1159 OP_REQUIRES(operands.at(cos_table_index).shape().rank() == 4);
1160}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RoPE::COS_TABLE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RoPE::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::RoPE::SIN_TABLE.

◆ visit() [34/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Select & node)
override

Definition at line 934 of file ShapeValidator.cc.

935{
936 // TODO Shape validation of select
937}

◆ visit() [35/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Shape & node)
override

Definition at line 970 of file ShapeValidator.cc.

971{
972 const auto &operands = _graph.operands();
973 const auto output_index{node.getOutputs().at(0)};
974 if (operands.at(output_index).info().isDynamic())
975 return;
976
977 [[maybe_unused]] const auto input_index{node.getInputs().at(0)};
978 OP_REQUIRES(operands.at(output_index).shape().rank() == 1);
979}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [36/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Softmax & node)
override

Definition at line 218 of file ShapeValidator.cc.

219{
220 const auto &operands = _graph.operands();
221 const auto output_index{node.getOutputs().at(0)};
222 if (operands.at(output_index).info().isDynamic())
223 return;
224
225 const auto input_index{node.getInputs().at(0)};
226
227 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
228}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [37/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToBatchND &node)
override

Definition at line 381 of file ShapeValidator.cc.

382{
383 const auto &operands = _graph.operands();
384 const auto ofm_index{node.getOutputs().at(0)};
385 if (operands.at(ofm_index).info().isDynamic())
386 return;
387
388 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
389 const auto block_size_index{
390 node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
391 const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
392
393 const auto input_shape = operands.at(ifm_index).shape().asFeature();
394 const auto output_shape = operands.at(ofm_index).shape().asFeature();
395
396 // All requirement as per NNAPI specification.
397 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
398 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
399 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
400 OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2);
401
402 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
403 OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2);
404 OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2);
405
406 OP_REQUIRES(input_shape.C == output_shape.C);
407}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToBatchND::BLOCK_SIZE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToBatchND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToBatchND::PADDINGS.

◆ visit() [38/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToDepth &node)
override

Definition at line 409 of file ShapeValidator.cc.

410{
411 const auto &operands = _graph.operands();
412 const auto ofm_index{node.getOutputs().at(0)};
413 if (operands.at(ofm_index).info().isDynamic())
414 return;
415
416 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
417
418 const auto input_shape = operands.at(ifm_index).shape().asFeature();
419 const auto output_shape = operands.at(ofm_index).shape().asFeature();
420 const auto block_size = node.param().block_size;
421
422 // All assertions as per NNAPI specification.
423 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
424 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
425 OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
426 OP_REQUIRES(input_shape.N == output_shape.N);
427 OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
428}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToDepth::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToDepth::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToDepth::param().

◆ visit() [39/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Split &node)
override

Definition at line 951 of file ShapeValidator.cc.

952{
953 const auto &operands = _graph.operands();
954 const auto output_index{node.getOutputs().at(0)};
955 if (operands.at(output_index).info().isDynamic())
956 return;
957
958 const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
959 const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
960
961 const auto num_splits = node.param().num_splits;
962 const auto input_rank = operands.at(input_index).shape().rank();
963 auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base());
964 axis = axis < 0 ? axis + input_rank : axis;
965
966 OP_REQUIRES(axis >= 0 && axis < input_rank);
967 OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0);
968}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Split::AXIS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Split::INPUT, onert::ir::operation::Split::Param::num_splits, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Split::param().

◆ visit() [40/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SquaredDifference &node)
override

Definition at line 1017 of file ShapeValidator.cc.

1018{
1019 const auto &operands = _graph.operands();
1020 const auto output_index{node.getOutputs().at(0)};
1021 const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
1022 const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
1023
1024 // Check for dimension constraints
1025 if (operands.at(output_index).info().isDynamic())
1026 return;
1027
1028 auto output_shape = operands.at(output_index).shape();
1029 auto lhs_shape = operands.at(lhs_index).shape();
1030 auto rhs_shape = operands.at(rhs_index).shape();
1031 // Check for output rank
1032 OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
1033 auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
1034
1035 for (int idx = 1; idx <= min_rank; idx++)
1036 {
1037 int l_idx = lhs_shape.rank() - idx;
1038 int r_idx = rhs_shape.rank() - idx;
1039 int out_idx = output_shape.rank() - idx;
1040
1041 OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
1042
1043 auto l_dims = lhs_shape.dim(l_idx);
1044 auto r_dims = rhs_shape.dim(r_idx);
1045 auto out_dims = output_shape.dim(out_idx);
1046
1047 OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
1048 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
1049 }
1050 auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
1051 for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
1052 {
1053 int out_idx = output_shape.rank() - idx;
1054 int tmp_idx = tmp_shape.rank() - idx;
1055
1056 OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
1057 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
1058 }
1059}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SquaredDifference::LHS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SquaredDifference::RHS.

◆ visit() [41/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::StridedSlice &node)
override

Definition at line 939 of file ShapeValidator.cc.

940{
941 const auto &operands = _graph.operands();
942 const auto output_index{node.getOutputs().at(0)};
943 const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
944
945 if (operands.at(output_index).info().isDynamic())
946 return;
947
948 OP_REQUIRES(operands.at(input_index).shape().rank() <= 4);
949}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::StridedSlice::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [42/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Tile &node)
override

Definition at line 1060 of file ShapeValidator.cc.

1061{
1062 const auto &operands = _graph.operands();
1063 const auto output_index{node.getOutputs().at(0)};
1064 if (operands.at(output_index).info().isDynamic())
1065 return;
1066
1067 const auto input_index{node.getInputs().at(0)};
1068 const auto multiple_index{node.getInputs().at(1)};
1069
1070 OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1);
1071 OP_REQUIRES(operands.at(multiple_index).shape().dim(0) ==
1072 operands.at(input_index).shape().rank());
1073 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
1074}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [43/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Transpose &node)
override

Definition at line 318 of file ShapeValidator.cc.

319{
320 const auto &operands = _graph.operands();
321 const auto output_index{node.getOutputs().at(0)};
322 if (operands.at(output_index).info().isDynamic())
323 return;
324
325 const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
326 const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
327
328 const auto &output_shape = operands.at(output_index).shape();
329 const auto &input_shape = operands.at(input_index).shape();
330
331 OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 ||
332 input_shape.rank() ==
333 static_cast<int>(operands.at(perm_index).shape().num_elements()));
334 OP_REQUIRES(input_shape.rank() == output_shape.rank());
335}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Transpose::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Transpose::PERMUTATION.

◆ visit() [44/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::TransposeConv &node)
override

Definition at line 522 of file ShapeValidator.cc.

523{
524 // shape check
525 const auto &operands = _graph.operands();
526 const auto ofm_index{node.getOutputs().at(0)};
527
528 if (operands.at(ofm_index).info().isDynamic())
529 return;
530
531 const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
532 const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
533
534 // Only 4D tensors are supported
535 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
536 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank());
537 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank());
538
539 const auto ofm_shape = operands.at(ofm_index).shape().asFeature();
540 const auto ifm_shape = operands.at(ifm_index).shape().asFeature();
541 // The kernel has only IHWO layout on frontend
542 // So ker_shape is treated here below
543 // I -> N
544 // H -> H
545 // W -> W
546 // O -> C
547 const auto ker_shape = operands.at(ker_index).shape().asFeature();
548
549 OP_REQUIRES(ifm_shape.N == ofm_shape.N);
550 OP_REQUIRES(ifm_shape.C == ker_shape.C);
551 OP_REQUIRES(ker_shape.N == ofm_shape.C);
552}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::TransposeConv::INPUT, onert::ir::operation::TransposeConv::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [45/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Unpack &node)
override

Definition at line 897 of file ShapeValidator.cc.

898{
899 const auto &operands = _graph.operands();
900 const auto axis{node.param().axis};
901 const auto output_index{node.getInputs().at(0)};
902 if (operands.at(output_index).info().isDynamic())
903 return;
904
905 const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
906
907 const auto &input_shape = operands.at(input_index).shape();
908 const auto input_rank = static_cast<int32_t>(input_shape.rank());
909
910 OP_REQUIRES(axis >= -input_rank && axis < input_rank);
911}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Unpack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::operation::Unpack::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Unpack::param().

◆ visit() [46/46]

void onert::compiler::ShapeValidator::visit ( const ir::operation::While &node)
override

Definition at line 1011 of file ShapeValidator.cc.

1012{
1013 // This validator does not check shape. So checking isDynamic() is skipped.
1014 // TODO Add to validate with subgraphs
1015}

The documentation for this class was generated from the following files: