ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::compiler::ShapeValidator Class Reference

#include <ShapeValidator.h>

Collaboration diagram for onert::compiler::ShapeValidator:

Public Member Functions

 ShapeValidator (void)=delete
 
 ShapeValidator (const ir::Graph &graph)
 
 ShapeValidator (const ShapeValidator &)=delete
 
 ShapeValidator (ShapeValidator &&)=delete
 
 ~ShapeValidator ()=default
 
ShapeValidator & operator= (const ShapeValidator &)=delete
 
ShapeValidator & operator= (ShapeValidator &&)=delete
 
void operator() ()
 
void visit (const ir::operation::BatchMatMul &node) override
 
void visit (const ir::operation::BatchToSpaceND &node) override
 
void visit (const ir::operation::BCQFullyConnected &node) override
 
void visit (const ir::operation::BCQGather &node) override
 
void visit (const ir::operation::BCQUnembedding &node) override
 
void visit (const ir::operation::BroadcastTo &node) override
 
void visit (const ir::operation::Comparison &node) override
 
void visit (const ir::operation::Conv2D &node) override
 
void visit (const ir::operation::DepthToSpace &node) override
 
void visit (const ir::operation::DepthwiseConv2D &node) override
 
void visit (const ir::operation::DynamicUpdateSlice &node) override
 
void visit (const ir::operation::ElementwiseActivation &node) override
 
void visit (const ir::operation::ElementwiseBinary &node) override
 
void visit (const ir::operation::ElementwiseUnary &node) override
 
void visit (const ir::operation::EmbeddingLookup &node) override
 
void visit (const ir::operation::ExpandDims &node) override
 
void visit (const ir::operation::FullyConnected &node) override
 
void visit (const ir::operation::Gather &node) override
 
void visit (const ir::operation::HashtableLookup &node) override
 
void visit (const ir::operation::If &node) override
 
void visit (const ir::operation::InstanceNorm &node) override
 
void visit (const ir::operation::L2Normalization &node) override
 
void visit (const ir::operation::LogSoftmax &node) override
 
void visit (const ir::operation::LSTM &node) override
 
void visit (const ir::operation::Pack &node) override
 
void visit (const ir::operation::Pad &node) override
 
void visit (const ir::operation::Permute &node) override
 
void visit (const ir::operation::Pool2D &node) override
 
void visit (const ir::operation::Range &node) override
 
void visit (const ir::operation::Reduce &node) override
 
void visit (const ir::operation::ResizeBilinear &node) override
 
void visit (const ir::operation::Reverse &node) override
 
void visit (const ir::operation::RmsNorm &node) override
 
void visit (const ir::operation::RNN &node) override
 
void visit (const ir::operation::RoPE &node) override
 
void visit (const ir::operation::Select &node) override
 
void visit (const ir::operation::Shape &node) override
 
void visit (const ir::operation::Softmax &node) override
 
void visit (const ir::operation::SpaceToBatchND &node) override
 
void visit (const ir::operation::SpaceToDepth &node) override
 
void visit (const ir::operation::Split &node) override
 
void visit (const ir::operation::SquaredDifference &node) override
 
void visit (const ir::operation::StridedSlice &node) override
 
void visit (const ir::operation::Tile &node) override
 
void visit (const ir::operation::Transpose &node) override
 
void visit (const ir::operation::TransposeConv &node) override
 
void visit (const ir::operation::Unpack &node) override
 
void visit (const ir::operation::While &node) override
 
- Public Member Functions inherited from onert::ir::OperationVisitor
virtual ~OperationVisitor ()=default
 

Detailed Description

Definition at line 32 of file ShapeValidator.h.

Constructor & Destructor Documentation

◆ ShapeValidator() [1/4]

onert::compiler::ShapeValidator::ShapeValidator ( void  )
delete

◆ ShapeValidator() [2/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ir::Graph & graph)

Definition at line 35 of file ShapeValidator.cc.

35: _graph{graph} {}

◆ ShapeValidator() [3/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ShapeValidator & )
delete

◆ ShapeValidator() [4/4]

onert::compiler::ShapeValidator::ShapeValidator ( ShapeValidator &&  )
delete

◆ ~ShapeValidator()

onert::compiler::ShapeValidator::~ShapeValidator ( )
default

Member Function Documentation

◆ operator()()

void onert::compiler::ShapeValidator::operator() ( )

Definition at line 50 of file ShapeValidator.cc.

51{
52 _graph.operations().iterate(
53 [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); });
54}
const Operations & operations() const override
Definition Graph.h:105
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
::onert::util::Index< uint32_t, OperationIndexTag > OperationIndex
Definition Index.h:30

References onert::ir::IOperation::accept(), onert::util::ObjectManager< Index, Object >::iterate(), and onert::ir::Graph::operations().

◆ operator=() [1/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( const ShapeValidator & )
delete

◆ operator=() [2/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( ShapeValidator &&  )
delete

◆ visit() [1/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchMatMul & node)
override

Definition at line 56 of file ShapeValidator.cc.

57{
58 const auto &operands = _graph.operands();
59 const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
60 const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
61 const auto out_index{node.getOutputs().at(0)};
62
63 if (operands.at(out_index).info().isDynamic())
64 return;
65
66 OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4);
67 OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4);
68 OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2);
69 OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2);
70}
#define OP_REQUIRES(EXP)
const Operands & operands() const override
Definition Graph.h:103

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchMatMul::LHS, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::BatchMatMul::RHS.

◆ visit() [2/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchToSpaceND & node)
override

Definition at line 72 of file ShapeValidator.cc.

73{
74 const auto &operands = _graph.operands();
75 const auto ofm_index{node.getOutputs().at(0)};
76 if (operands.at(ofm_index).info().isDynamic())
77 return;
78
79 const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
80 const auto block_size_index{
81 node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
82
83 const auto input_shape = operands.at(ifm_index).shape().asFeature();
84 const auto output_shape = operands.at(ofm_index).shape().asFeature();
85
86 // All requirement as per NNAPI specification.
87 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
88 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
89 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
90
91 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
92
93 if (node.getInputs().size() != 2)
94 {
95 const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
96 OP_REQUIRES(operands.at(crops_index).shape().rank() == 2);
97 OP_REQUIRES(operands.at(crops_index).shape().dim(0) ==
98 (operands.at(ifm_index).shape().rank() - 2));
99 OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2);
100 }
101
102 OP_REQUIRES(input_shape.C == output_shape.C);
103}
const Object & at(const Index &index) const
Get the object that is associated with the given index.
const luci_interpreter::RuntimeShape output_shape

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BatchToSpaceND::BLOCK_SIZE, onert::ir::operation::BatchToSpaceND::CROPS_DATA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchToSpaceND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::OperandIndexSequence::size().

◆ visit() [3/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQFullyConnected & node)
override

Definition at line 105 of file ShapeValidator.cc.

106{
107 const auto &operands = _graph.operands();
108 const auto ofm_index{node.getOutputs().at(0)};
109 if (operands.at(ofm_index).info().isDynamic())
110 return;
111
112 const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
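113 const auto weight_scales_index{
114 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
115 const auto weight_binary_index{
116 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
117 const auto weight_cluster_index{
118 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};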
119 const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
120
121 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
122 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
123 OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
124 OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
125 OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
126
127 OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
128
129 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
130 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
131
132 // more shape validation will be done inside kernel.
133
134 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
135}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BCQFullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQFullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::BCQFullyConnected::WEIGHTS_BINARY, onert::ir::operation::BCQFullyConnected::WEIGHTS_CLUSTERS, and onert::ir::operation::BCQFullyConnected::WEIGHTS_SCALES.

◆ visit() [4/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQGather & node)
override

Definition at line 137 of file ShapeValidator.cc.

138{
139 const auto &operands = _graph.operands();
140 const auto ofm_index{node.getOutputs().at(0)};
141 if (operands.at(ofm_index).info().isDynamic())
142 return;
143
144 const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
145 const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
146 const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
147 const auto input_clusters_index{
148 node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
149
150 OP_REQUIRES(operands.at(indices_index).shape().rank() <=
151 2); // TODO : support rank up to 4 or more
152 OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2);
153 OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1);
154 OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2);
155
156 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0);
157 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2);
158
159 // more shape validation will be done inside kernel.
160}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQGather::INDICES, onert::ir::operation::BCQGather::INPUT_BINARY, onert::ir::operation::BCQGather::INPUT_CLUSTERS, onert::ir::operation::BCQGather::INPUT_SCALES, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [5/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQUnembedding & node)
override

Definition at line 162 of file ShapeValidator.cc.

163{
164 const auto &operands = _graph.operands();
165 const auto ofm_index{node.getOutputs().at(0)};
166 if (operands.at(ofm_index).info().isDynamic())
167 return;
168
169 const auto ifm_index{node.getInputs().at(ir::operation::BCQUnembedding::Input::INPUT)};
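170 const auto weight_scales_index{
171 node.getInputs().at(ir::operation::BCQUnembedding::Input::WEIGHTS_SCALES)};
172 const auto weight_binary_index{
173 node.getInputs().at(ir::operation::BCQUnembedding::Input::WEIGHTS_BINARY)};
174 const auto weight_cluster_index{
175 node.getInputs().at(ir::operation::BCQUnembedding::Input::WEIGHTS_CLUSTERS)};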
176 const auto bias_index{node.getInputs().at(ir::operation::BCQUnembedding::Input::BIAS)};
177
178 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
179 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
180 OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
181 OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
182 OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
183
184 OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
185
186 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
187 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
188
189 // more shape validation will be done inside kernel.
190
191 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
192}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BCQUnembedding::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQUnembedding::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::BCQUnembedding::WEIGHTS_BINARY, onert::ir::operation::BCQUnembedding::WEIGHTS_CLUSTERS, and onert::ir::operation::BCQUnembedding::WEIGHTS_SCALES.

◆ visit() [6/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BroadcastTo & node)
override

Definition at line 194 of file ShapeValidator.cc.

195{
196 const auto &operands = _graph.operands();
197 const auto output_index{node.getOutputs().at(0)};
198 if (operands.at(output_index).info().isDynamic())
199 return;
200
201 const auto input_index{node.getInputs().at(ir::operation::BroadcastTo::Input::INPUT)};
202 const auto shape_index{node.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
203 const auto &input_shape = operands.at(input_index).shape();
204 const auto &output_shape_vec = operands.at(shape_index).asVector<int32_t>();
205 int input_num_dims = input_shape.rank();
206 int output_num_dims = output_shape_vec.size();
207 OP_REQUIRES(input_num_dims <= output_num_dims);
208
209 int extending_dims = output_num_dims - input_num_dims;
210 for (int idx = 0; idx < input_num_dims; ++idx)
211 {
212 OP_REQUIRES(input_shape.dim(idx) == 1 ||
213 input_shape.dim(idx) == output_shape_vec.at(extending_dims + idx));
214 }
215}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BroadcastTo::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::BroadcastTo::SHAPE.

◆ visit() [7/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Comparison & node)
override

Definition at line 217 of file ShapeValidator.cc.

218{
219 // TODO Shape validation of comparison
220}

◆ visit() [8/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Conv2D & node)
override

Definition at line 222 of file ShapeValidator.cc.

223{
224 const auto &operands = _graph.operands();
225 const auto ofm_index{node.getOutputs().at(0)};
226 if (operands.at(ofm_index).info().isDynamic())
227 return;
228
229 const auto ifm_index{node.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
230 const auto ker_index{node.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
231 const auto bias_index{node.getInputs().at(ir::operation::Conv2D::Input::BIAS)};
232
233 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
234 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
235 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
236 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
237}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Conv2D::INPUT, onert::ir::operation::Conv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [9/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthToSpace & node)
override

Definition at line 239 of file ShapeValidator.cc.

240{
241 const auto &operands = _graph.operands();
242 int32_t block_size = node.param().block_size;
243
244 // shape check
245 const auto output_index{node.getOutputs().at(0)};
246 if (operands.at(output_index).info().isDynamic())
247 return;
248
249 const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
250
251 const auto output_shape = operands.at(output_index).shape().asFeature();
252 const auto input_shape = operands.at(input_index).shape().asFeature();
253
254 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
255 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
256
257 {
258 OP_REQUIRES(output_shape.N == input_shape.N);
259 OP_REQUIRES(output_shape.H == input_shape.H * block_size);
260 OP_REQUIRES(output_shape.W == input_shape.W * block_size);
261 OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
262 OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
263 }
264}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthToSpace::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthToSpace::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::DepthToSpace::param().

◆ visit() [10/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthwiseConv2D & node)
override

Definition at line 266 of file ShapeValidator.cc.

267{
268 const auto &operands = _graph.operands();
269 const auto ofm_index{node.getOutputs().at(0)};
270 if (operands.at(ofm_index).info().isDynamic())
271 return;
272
273 const auto ifm_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)};
274 const auto ker_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)};
275 const auto bias_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::BIAS)};
276
277 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
278 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
279 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
280 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
281}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthwiseConv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthwiseConv2D::INPUT, onert::ir::operation::DepthwiseConv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [11/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DynamicUpdateSlice & node)
override

Definition at line 283 of file ShapeValidator.cc.

284{
285 const auto &operands = _graph.operands();
286 const auto output_index{node.getOutputs().at(0)};
287 if (operands.at(output_index).info().isDynamic())
288 return;
289
290 const auto operand_index{node.getInputs().at(ir::operation::DynamicUpdateSlice::Input::OPERAND)};
291 const auto update_index{node.getInputs().at(ir::operation::DynamicUpdateSlice::Input::UPDATE)};
292 const auto indices_index{node.getInputs().at(ir::operation::DynamicUpdateSlice::Input::INDICES)};
293
294 OP_REQUIRES(operands.at(indices_index).shape().rank() == 1);
295 OP_REQUIRES(operands.at(indices_index).shape().dim(0) ==
296 operands.at(operand_index).shape().rank());
297 OP_REQUIRES(operands.at(operand_index).shape().rank() ==
298 operands.at(update_index).shape().rank());
299 for (int i = 0; i < operands.at(operand_index).shape().rank(); i++)
300 {
301 OP_REQUIRES(operands.at(operand_index).shape().dim(i) >=
302 operands.at(update_index).shape().dim(i));
303 }
304 OP_REQUIRES(operands.at(operand_index).shape() == operands.at(output_index).shape());
305}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DynamicUpdateSlice::INDICES, OP_REQUIRES, onert::ir::operation::DynamicUpdateSlice::OPERAND, onert::ir::Graph::operands(), and onert::ir::operation::DynamicUpdateSlice::UPDATE.

◆ visit() [12/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseActivation & node)
override

Definition at line 307 of file ShapeValidator.cc.

307{ checkUnaryOp(node); }

◆ visit() [13/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseBinary & node)
override

Definition at line 309 of file ShapeValidator.cc.

310{
311 // TODO Shape validation of ElementwiseBinary
312}

◆ visit() [14/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseUnary & node)
override

Definition at line 314 of file ShapeValidator.cc.

315{
316 const auto &operands = _graph.operands();
317 const auto output_index{node.getOutputs().at(0)};
318 const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
319
320 if (operands.at(output_index).info().isDynamic())
321 return;
322
323 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
324}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseUnary::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [15/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::EmbeddingLookup & node)
override

Definition at line 326 of file ShapeValidator.cc.

327{
328 const auto &operands = _graph.operands();
329 const auto output_index{node.getOutputs().at(0)};
330 const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
331 const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
332
333 const auto &output_obj = operands.at(output_index);
334 const auto &lookups_obj = operands.at(lookups_index);
335 const auto &values_obj = operands.at(values_index);
336
337 // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
338 // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
339 {
340 if (operands.at(output_index).info().isDynamic())
341 return;
342
343 const auto &output_shape = output_obj.shape();
344 const auto &lookups_shape = lookups_obj.shape();
345 const auto &values_shape = values_obj.shape();
346
347 OP_REQUIRES(lookups_shape.rank() == 1);
348 OP_REQUIRES(values_shape.rank() >= 2);
349
350 // output should be a n-D tensor with the same rank and shape as the values tensor, except for
351 // the first dimension which has the same size as lookups' only dimension.
352 OP_REQUIRES(output_shape.rank() == values_shape.rank());
353 OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
354 for (int n = 1; n < output_shape.rank(); ++n)
355 {
356 OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
357 }
358 }
359}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::EmbeddingLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::EmbeddingLookup::VALUES.

◆ visit() [16/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ExpandDims & node)
override

Definition at line 361 of file ShapeValidator.cc.

362{
363 const auto &operands = _graph.operands();
364 const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
365
366 if (operands.at(axis_index).info().isDynamic())
367 return;
368 OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1);
369}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ExpandDims::AXIS, onert::ir::Operation::getInputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [17/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::FullyConnected & node)
override

Definition at line 371 of file ShapeValidator.cc.

372{
373 const auto &operands = _graph.operands();
374 const auto ofm_index{node.getOutputs().at(0)};
375 if (operands.at(ofm_index).info().isDynamic())
376 return;
377
378 const auto ifm_index{node.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
379 const auto ker_index{node.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
380 const auto bias_index{node.getInputs().at(ir::operation::FullyConnected::Input::BIAS)};
381
382 OP_REQUIRES(operands.at(ifm_index).shape().rank() >= 2);
383 OP_REQUIRES(operands.at(ker_index).shape().rank() == 2);
384 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
385}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::FullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::FullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::FullyConnected::WEIGHT.

◆ visit() [18/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Gather & node)
override

Definition at line 387 of file ShapeValidator.cc.

388{
389 const auto &operands = _graph.operands();
390 const auto ofm_index{node.getOutputs().at(0)};
391 if (operands.at(ofm_index).info().isDynamic())
392 return;
393
394 const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
395 const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
396
397 const auto &ifm_shape = operands.at(ifm_index).shape();
398 const auto &indices_shape = operands.at(indices_index).shape();
399 const auto &ofm_shape = operands.at(ofm_index).shape();
400
401 // Since gather implementation is general enough, we do not restrict max rank
402 OP_REQUIRES(ifm_shape.rank() + indices_shape.rank() - 1 == ofm_shape.rank());
403}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Gather::INDICES, onert::ir::operation::Gather::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [19/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::HashtableLookup & node)
override

Definition at line 405 of file ShapeValidator.cc.

406{
407 const auto &operands = _graph.operands();
408 const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
409 const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
410 const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
411 const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
412
413 const auto &output_obj = operands.at(output_index);
414 const auto &lookups_obj = operands.at(lookups_index);
415 const auto &keys_obj = operands.at(keys_index);
416 const auto &values_obj = operands.at(values_index);
417
418 if (operands.at(output_index).info().isDynamic())
419 return;
420
421 const auto &output_shape = output_obj.shape();
422 const auto &lookups_shape = lookups_obj.shape();
423 const auto &keys_shape = keys_obj.shape();
424 const auto &values_shape = values_obj.shape();
425
426 OP_REQUIRES(values_shape.rank() == output_shape.rank());
427 OP_REQUIRES(lookups_shape.rank() == 1);
428 OP_REQUIRES(keys_shape.rank() == 1);
429 OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
430 OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
431}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::HashtableLookup::KEYS, onert::ir::operation::HashtableLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::HashtableLookup::OUTPUT, output_shape, and onert::ir::operation::HashtableLookup::VALUES.

◆ visit() [20/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::If & node)
override

Definition at line 433 of file ShapeValidator.cc.

434{
435 // TODO Add to validate with subgraphs
436}

◆ visit() [21/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::InstanceNorm & node)
override

Definition at line 438 of file ShapeValidator.cc.

439{
440 const auto &operands = _graph.operands();
441 const auto ofm_index{node.getOutputs().at(0)};
442 if (operands.at(ofm_index).info().isDynamic())
443 return;
444
445 const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
446 const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
447 const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
448
449 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
450 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
451 OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1);
452 OP_REQUIRES(operands.at(beta_index).shape().rank() == 1);
453}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::InstanceNorm::BETA, onert::ir::operation::InstanceNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::InstanceNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [22/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::L2Normalization & node)
override

Definition at line 455 of file ShapeValidator.cc.

456{
457 const auto &operands = _graph.operands();
458 const auto ofm_index{node.getOutputs().at(0)};
459 if (operands.at(ofm_index).info().isDynamic())
460 return;
461
462 const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
463
464 auto ifm_shape = operands.at(ifm_index).shape();
465 auto ofm_shape = operands.at(ofm_index).shape();
466
467 OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
468
469 for (auto i = 0; i < ifm_shape.rank(); i++)
470 {
471 OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
472 }
473}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::L2Normalization::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [23/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LogSoftmax & node)
override

Definition at line 475 of file ShapeValidator.cc.

476{
477 const auto &operands = _graph.operands();
478 const auto output_index{node.getOutputs().at(0)};
479 if (operands.at(output_index).info().isDynamic())
480 return;
481
482 const auto input_index{node.getInputs().at(0)};
483
484 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
485}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [24/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LSTM node)
override

Definition at line 487 of file ShapeValidator.cc.

488{
489 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
490 // TODO Support dynamic rnn
491 const auto &operands = _graph.operands();
492 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
493 if (operands.at(output_index).info().isDynamic())
494 return;
495
496 const auto scratch_buffer_index{
497 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
498 const auto output_state_out_index{
499 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
500 const auto cell_state_out_index{
501 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
502
503 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
504 const auto input_to_input_weights_index{
505 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
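506 const auto input_to_forget_weights_index{
507 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
508 const auto input_to_cell_weights_index{
509 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
510 const auto input_to_output_weights_index{
511 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};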
512 const auto recurrent_to_input_weights_index{
513 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
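514 const auto recurrent_to_forget_weights_index{
515 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
516 const auto recurrent_to_cell_weights_index{
517 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
518 const auto recurrent_to_output_weights_index{
519 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};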
520 const auto cell_to_input_weights_index{
521 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
522 const auto cell_to_forget_weights_index{
523 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
524 const auto cell_to_output_weights_index{
525 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
526 const auto input_gate_bias_index{
527 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
528 const auto forget_gate_bias_index{
529 node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
530 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
531 const auto output_gate_bias_index{
532 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
533 const auto projection_weights_index{
534 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
535 const auto projection_bias_index{
536 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
537 const auto output_state_in_index{
538 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
539 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
540
541 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
542 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
543 {
544 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
545 operands.at(output_index).shape().dim(i));
546 }
547 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
548 operands.at(output_index).shape().rank() == 3) &&
549 (operands.at(input_index).shape().rank() == 2 ||
550 operands.at(input_index).shape().rank() == 3) &&
551 (!operands.exist(input_to_input_weights_index) ||
552 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
553 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
554 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
555 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
556 (!operands.exist(recurrent_to_input_weights_index) ||
557 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
558 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
559 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
560 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
561 (!operands.exist(projection_weights_index) ||
562 operands.at(projection_weights_index).shape().rank() == 2) &&
563 operands.at(output_state_in_index).shape().rank() == 2 &&
564 operands.at(cell_state_in_index).shape().rank() == 2);
565
566 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
567 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
568 (!operands.exist(cell_to_forget_weights_index) ||
569 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
570 (!operands.exist(cell_to_output_weights_index) ||
571 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
572 (!operands.exist(input_gate_bias_index) ||
573 operands.at(input_gate_bias_index).shape().rank() == 1) &&
574 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
575 operands.at(cell_bias_index).shape().rank() == 1 &&
576 operands.at(output_gate_bias_index).shape().rank() == 1 &&
577 (!operands.exist(projection_bias_index) ||
578 operands.at(projection_bias_index).shape().rank() == 1));
579
580 // CIFG assertion
581 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
582 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
583 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
584 (!operands.exist(recurrent_to_input_weights_index) ||
585 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
586 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
587 (!operands.exist(input_gate_bias_index) ||
588 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
589 (!operands.exist(cell_to_input_weights_index) ||
590 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
591 ((operands.exist(input_to_input_weights_index) &&
592 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
593 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
594 (operands.exist(recurrent_to_input_weights_index) &&
595 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
596 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
597 (operands.exist(input_gate_bias_index) &&
598 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
599
600 // Peephole assertion
601 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
602 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
603 (!operands.exist(cell_to_output_weights_index) ||
604 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
605 ((operands.exist(cell_to_forget_weights_index) &&
606 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
607 (operands.exist(cell_to_output_weights_index) &&
608 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
609
610 bool has_input_to_input_weights =
611 operands.exist(input_to_input_weights_index) &&
612 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
613 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
614 bool has_recurrent_to_input_weights =
615 operands.exist(recurrent_to_input_weights_index) &&
616 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
617 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
618 bool has_input_gate_bias =
619 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
620 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
621 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
622 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
623 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
624 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
625 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
626 bool has_projection_weights = operands.exist(projection_weights_index) &&
627 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
628 operands.at(projection_weights_index).shape().dim(1) != 0);
629 bool has_projection_bias =
630 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
631
632 // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
633 // true: no CIFG
634 // false: CIFG
635 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
636
637 // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
638 // true: peephole
639 // false: no peephole
640 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
641
642 // NOTE The projection weights may have data but the projection bias may not.
643 bool has_projection_param = has_projection_weights;
644
645 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
646 ? operands.at(input_index).shape().dim(1)
647 : operands.at(input_index).shape().dim(0);
648 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
649 batch_size == operands.at(cell_state_in_index).shape().dim(0));
650
651 const auto input_size =
652 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
653 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
654 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
655 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
656
657 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
658 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
659 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
660 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
661 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
662 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
663 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
664 num_units == operands.at(cell_bias_index).shape().dim(0) &&
665 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
666 num_units == operands.at(cell_state_in_index).shape().dim(1));
667
668 const auto output_size =
669 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
670 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
671 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
672 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
673 output_size == operands.at(output_state_in_index).shape().dim(1));
674
675 if (has_cifg_param)
676 {
677 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
678 OP_REQUIRES(
679 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
680 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
681 ((operands.exist(cell_to_input_weights_index) &&
682 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
683 (!operands.exist(cell_to_input_weights_index) ||
684 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
685 num_units == operands.at(input_gate_bias_index).shape().dim(0));
686 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
687 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
688 has_input_gate_bias);
689 if (has_cell_to_input_weights)
690 {
691 // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
692 OP_REQUIRES(has_peephole_param);
693 }
694 if (operands.exist(scratch_buffer_index))
695 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
696 }
697 else
698 {
699 if (operands.exist(scratch_buffer_index))
700 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
701 }
702
703 if (has_peephole_param)
704 {
705 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
706 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
707 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
708 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
709 }
710
711 if (has_projection_param)
712 {
713 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
714 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
715 if (has_projection_bias)
716 {
717 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
718 }
719 }
720
721 if (operands.exist(scratch_buffer_index))
722 {
723 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
724 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
725 }
726
727 if (operands.exist(output_state_out_index))
728 {
729 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
730 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
731 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
732 }
733
734 if (operands.exist(cell_state_out_index))
735 {
736 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
737 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
738 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
739 }
740}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::LSTM::CELL_BIAS, onert::ir::operation::LSTM::CELL_STATE_IN, onert::ir::operation::LSTM::CELL_STATE_OUT, onert::ir::operation::LSTM::CELL_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::FORGET_GATE_BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::LSTM::INPUT, onert::ir::operation::LSTM::INPUT_GATE_BIAS, onert::ir::operation::LSTM::INPUT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_OUTPUT_WEIGHTS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::LSTM::OUTPUT, onert::ir::operation::LSTM::OUTPUT_GATE_BIAS, onert::ir::operation::LSTM::OUTPUT_STATE_IN, onert::ir::operation::LSTM::OUTPUT_STATE_OUT, onert::ir::operation::LSTM::param(), onert::ir::operation::LSTM::PROJECTION_BIAS, onert::ir::operation::LSTM::PROJECTION_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::SCRATCH_BUFFER, and onert::ir::operation::LSTM::Param::time_major.

◆ visit() [25/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pack node)
override

Definition at line 742 of file ShapeValidator.cc.

743{
744 const auto &operands = _graph.operands();
745 const auto axis{node.param().axis};
746 const auto output_index{node.getOutputs().at(0)};
747 if (operands.at(output_index).info().isDynamic())
748 return;
749
750 // shape check
751 const auto &output_shape = operands.at(output_index).shape();
752 const auto output_rank = static_cast<int32_t>(output_shape.rank());
753
754 const auto input1_index{node.getInputs().at(0)};
755 const auto &input_shape = operands.at(input1_index).shape();
756
757 OP_REQUIRES(axis >= -output_rank && axis < output_rank);
758 for (const auto &index : node.getInputs())
759 {
760 OP_REQUIRES(input_shape == operands.at(index).shape());
761 }
762}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Pack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Pack::param().

◆ visit() [26/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pad node)
override

Definition at line 764 of file ShapeValidator.cc.

765{
766 const auto &operands = _graph.operands();
767 const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
768 OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32);
769
770 const auto output_index{node.getInputs().at(0)};
771 if (operands.at(output_index).info().isDynamic())
772 return;
773
774 const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
775
776 const auto &pad_shape = operands.at(pad_index).shape();
777 const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank());
778
779 OP_REQUIRES(pad_shape.rank() == 2);
780 OP_REQUIRES(pad_shape.dim(0) == input_rank);
781 OP_REQUIRES(pad_shape.dim(1) == 2);
782 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
783}
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
uint32_t rank(void) const
Definition TensorShape.h:35
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings)

References onert::ir::OperandIndexSequence::at(), loco::TensorShape::dim(), onert::ir::Operation::getInputs(), onert::ir::operation::Pad::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::Pad::PAD, and loco::TensorShape::rank().

◆ visit() [27/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Permute node)
override

Definition at line 785 of file ShapeValidator.cc.

786{
787 const auto &operands = _graph.operands();
788 const auto output_index{node.getOutputs().at(0)};
789 if (operands.at(output_index).info().isDynamic())
790 return;
791
792 const auto input_index{node.getInputs().at(0)};
793
794 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
795}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [28/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pool2D node)
override

Definition at line 797 of file ShapeValidator.cc.

798{
799 const auto &operands = _graph.operands();
800 const auto ofm_index{node.getOutputs().at(0)};
801 if (operands.at(ofm_index).info().isDynamic())
802 return;
803
804 const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
805
806 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
807}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pool2D::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [29/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Range node)
override

Definition at line 809 of file ShapeValidator.cc.

810{
811 const auto &operands = _graph.operands();
812 const auto output_index{node.getOutputs().at(0)};
813 const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
814 const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
815 const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
816
817 // Check for dimension constraints
818 if (operands.at(output_index).info().isDynamic())
819 return;
820
821 OP_REQUIRES(operands.at(start_index).shape().rank() == 0);
822 OP_REQUIRES(operands.at(limit_index).shape().rank() == 0);
823 OP_REQUIRES(operands.at(delta_index).shape().rank() == 0);
824}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Range::DELTA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Range::LIMIT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Range::START.

◆ visit() [30/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reduce node)
override

Definition at line 826 of file ShapeValidator.cc.

827{
828 const auto &operands = _graph.operands();
829 const auto output_index{node.getOutputs().at(0)};
830 if (operands.at(output_index).info().isDynamic())
831 return;
832
833 const auto &input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
834 const auto &input_shape = operands.at(input_index).shape();
835 const auto &output_shape = operands.at(output_index).shape();
836
837 OP_REQUIRES(input_shape.rank() <= 4);
838 OP_REQUIRES(output_shape.rank() <= input_shape.rank());
839
840 // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
841 // supports cases reducing height and width or reducing depth.
842 // TODO We have to support all cases of dimensions up to 4.
843 // For correct permuting, we have to set output's shape to be equal in dimension position of the
844 // input. But the positions of the same dimensions in the input and output may be set differently.
845 // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
846 // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
847 // extend it in 4 dimensions, it should be {1,1,3,5}.
848 // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
849 // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
850 // next operation is not desired.
851 if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
852 {
853 if (output_shape.rank() == 2)
854 {
855 // Reducing HW
856 OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
857 input_shape.dim(3) == output_shape.dim(1));
858 }
859 else if (output_shape.rank() == 3)
860 {
861 // Reducing C or
862 // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
863 OP_REQUIRES(
864 (input_shape.dim(0) == output_shape.dim(0) && input_shape.dim(1) == output_shape.dim(1) &&
865 input_shape.dim(2) == output_shape.dim(2)) ||
866 (input_shape.dim(0) == output_shape.dim(0) &&
867 (input_shape.dim(1) == output_shape.dim(1) || input_shape.dim(2) == output_shape.dim(1)) &&
868 input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
869 }
870 }
871}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reduce::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and output_shape.

◆ visit() [31/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ResizeBilinear node)
override

Definition at line 873 of file ShapeValidator.cc.

874{
875 const auto &operands = _graph.operands();
876 const auto output_index{node.getOutputs().at(0)};
877 const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
878
879 if (operands.at(output_index).info().isDynamic())
880 {
881 return;
882 }
883 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
884 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
885}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ResizeBilinear::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [32/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reverse node)
override

Definition at line 887 of file ShapeValidator.cc.

888{
889 const auto &operands = _graph.operands();
890 const auto output_index{node.getOutputs().at(0)};
891 const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
892
893 if (operands.at(output_index).info().isDynamic())
894 return;
895 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
896}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reverse::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [33/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RmsNorm node)
override

Definition at line 898 of file ShapeValidator.cc.

899{
900 const auto &operands = _graph.operands();
901 const auto ofm_index{node.getOutputs().at(0)};
902 if (operands.at(ofm_index).info().isDynamic())
903 return;
904
905 const auto ifm_index{node.getInputs().at(ir::operation::RmsNorm::Input::INPUT)};
906 const auto gamma_index{node.getInputs().at(ir::operation::RmsNorm::Input::GAMMA)};
907
908 const auto &ifm_shape = operands.at(ifm_index).shape();
909 const auto &ofm_shape = operands.at(ofm_index).shape();
910 const auto &gamma_shape = operands.at(gamma_index).shape();
911
912 OP_REQUIRES(ifm_shape.rank() == 3 || ifm_shape.rank() == 4);
913 OP_REQUIRES(ifm_shape == ofm_shape);
914 OP_REQUIRES(gamma_shape.rank() == 1);
915 OP_REQUIRES((gamma_shape.dim(0) == 1) ||
916 (gamma_shape.dim(0) == ifm_shape.dim(ifm_shape.rank() - 1)));
917}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RmsNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RmsNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [34/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RNN node)
override

Definition at line 919 of file ShapeValidator.cc.

920{
921 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
922 // TODO Support dynamic rnn
923 const auto &operands = _graph.operands();
924 const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
925 if (operands.at(output_index).info().isDynamic())
926 return;
927
928 const auto hidden_state_out_index{
929 node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
930
931 const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
932 const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
933 const auto recurrent_weights_index{
934 node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
935 const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
936 const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
937
938 const auto batch_size = operands.at(output_index).shape().dim(0);
939 const auto num_units = operands.at(output_index).shape().dim(1);
940
941 OP_REQUIRES(operands.at(output_index).shape().rank() == 2 &&
942 operands.at(hidden_state_out_index).shape().rank() == 2 &&
943 operands.at(input_index).shape().rank() == 2 &&
944 operands.at(weights_index).shape().rank() == 2 &&
945 operands.at(recurrent_weights_index).shape().rank() == 2 &&
946 operands.at(hidden_state_in_index).shape().rank() == 2);
947 OP_REQUIRES(operands.at(bias_index).shape().rank() == 1);
948
949 OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) &&
950 batch_size == operands.at(hidden_state_in_index).shape().dim(0) &&
951 batch_size == operands.at(hidden_state_out_index).shape().dim(0));
952 OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1));
953
954 OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) &&
955 num_units == operands.at(recurrent_weights_index).shape().dim(0) &&
956 num_units == operands.at(bias_index).shape().dim(0));
957 OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) &&
958 num_units == operands.at(recurrent_weights_index).shape().dim(1) &&
959 num_units == operands.at(hidden_state_in_index).shape().dim(1) &&
960 num_units == operands.at(hidden_state_out_index).shape().dim(1));
961}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RNN::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RNN::HIDDEN_STATE_IN, onert::ir::operation::RNN::HIDDEN_STATE_OUT, onert::ir::operation::RNN::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::RNN::OUTPUT, onert::ir::operation::RNN::RECURRENT_WEIGHTS, and onert::ir::operation::RNN::WEIGHTS.

◆ visit() [35/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RoPE node)
override

Definition at line 963 of file ShapeValidator.cc.

964{
965 const auto &operands = _graph.operands();
966 const auto ofm_index{node.getOutputs().at(0)};
967 if (operands.at(ofm_index).info().isDynamic())
968 return;
969
970 const auto ifm_index{node.getInputs().at(ir::operation::RoPE::Input::INPUT)};
971 const auto sin_table_index{node.getInputs().at(ir::operation::RoPE::Input::SIN_TABLE)};
972 const auto cos_table_index{node.getInputs().at(ir::operation::RoPE::Input::COS_TABLE)};
973
974 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
975 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
976 OP_REQUIRES(operands.at(sin_table_index).shape().rank() == 4);
977 OP_REQUIRES(operands.at(cos_table_index).shape().rank() == 4);
978}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RoPE::COS_TABLE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RoPE::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::RoPE::SIN_TABLE.

◆ visit() [36/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Select node)
override

Definition at line 980 of file ShapeValidator.cc.

981{
982 // TODO Shape validation of select
983}

◆ visit() [37/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Shape node)
override

Definition at line 985 of file ShapeValidator.cc.

986{
987 const auto &operands = _graph.operands();
988 const auto output_index{node.getOutputs().at(0)};
989 if (operands.at(output_index).info().isDynamic())
990 return;
991
992 [[maybe_unused]] const auto input_index{node.getInputs().at(0)};
993 OP_REQUIRES(operands.at(output_index).shape().rank() == 1);
994}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [38/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Softmax node)
override

Definition at line 996 of file ShapeValidator.cc.

997{
998 const auto &operands = _graph.operands();
999 const auto output_index{node.getOutputs().at(0)};
1000 if (operands.at(output_index).info().isDynamic())
1001 return;
1002
1003 const auto input_index{node.getInputs().at(0)};
1004
1005 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
1006}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [39/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToBatchND node)
override

Definition at line 1008 of file ShapeValidator.cc.

1009{
1010 const auto &operands = _graph.operands();
1011 const auto ofm_index{node.getOutputs().at(0)};
1012 if (operands.at(ofm_index).info().isDynamic())
1013 return;
1014
1015 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
1016 const auto block_size_index{
1017 node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
1018 const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
1019
1020 const auto input_shape = operands.at(ifm_index).shape().asFeature();
1021 const auto output_shape = operands.at(ofm_index).shape().asFeature();
1022
1023 // All requirement as per NNAPI specification.
1024 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
1025 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
1026 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
1027 OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2);
1028
1029 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
1030 OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2);
1031 OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2);
1032
1033 OP_REQUIRES(input_shape.C == output_shape.C);
1034}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToBatchND::BLOCK_SIZE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToBatchND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToBatchND::PADDINGS.

◆ visit() [40/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToDepth & node)
override

Definition at line 1036 of file ShapeValidator.cc.

1037{
1038 const auto &operands = _graph.operands();
1039 const auto ofm_index{node.getOutputs().at(0)};
1040 if (operands.at(ofm_index).info().isDynamic())
1041 return;
1042
1043 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
1044
1045 const auto input_shape = operands.at(ifm_index).shape().asFeature();
1046 const auto output_shape = operands.at(ofm_index).shape().asFeature();
1047 const auto block_size = node.param().block_size;
1048
1049 // All assertions as per NNAPI specification.
1050 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
1051 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
1052 OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
1053 OP_REQUIRES(input_shape.N == output_shape.N);
1054 OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
1055}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToDepth::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToDepth::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToDepth::param().

◆ visit() [41/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Split & node)
override

Definition at line 1057 of file ShapeValidator.cc.

1058{
1059 const auto &operands = _graph.operands();
1060 const auto output_index{node.getOutputs().at(0)};
1061 if (operands.at(output_index).info().isDynamic())
1062 return;
1063
1064 const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
1065 const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
1066
1067 const auto num_splits = node.param().num_splits;
1068 const auto input_rank = operands.at(input_index).shape().rank();
1069 auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base());
1070 axis = axis < 0 ? axis + input_rank : axis;
1071
1072 OP_REQUIRES(axis >= 0 && axis < input_rank);
1073 OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0);
1074}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Split::AXIS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Split::INPUT, onert::ir::operation::Split::Param::num_splits, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Split::param().

◆ visit() [42/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SquaredDifference & node)
override

Definition at line 1076 of file ShapeValidator.cc.

1077{
1078 const auto &operands = _graph.operands();
1079 const auto output_index{node.getOutputs().at(0)};
1080 const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
1081 const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
1082
1083 // Check for dimension constraints
1084 if (operands.at(output_index).info().isDynamic())
1085 return;
1086
1087 auto output_shape = operands.at(output_index).shape();
1088 auto lhs_shape = operands.at(lhs_index).shape();
1089 auto rhs_shape = operands.at(rhs_index).shape();
1090 // Check for output rank
1091 OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
1092 auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
1093
1094 for (int idx = 1; idx <= min_rank; idx++)
1095 {
1096 int l_idx = lhs_shape.rank() - idx;
1097 int r_idx = rhs_shape.rank() - idx;
1098 int out_idx = output_shape.rank() - idx;
1099
1100 OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
1101
1102 auto l_dims = lhs_shape.dim(l_idx);
1103 auto r_dims = rhs_shape.dim(r_idx);
1104 auto out_dims = output_shape.dim(out_idx);
1105
1106 OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
1107 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
1108 }
1109 auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
1110 for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
1111 {
1112 int out_idx = output_shape.rank() - idx;
1113 int tmp_idx = tmp_shape.rank() - idx;
1114
1115 OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
1116 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
1117 }
1118}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SquaredDifference::LHS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SquaredDifference::RHS.

◆ visit() [43/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::StridedSlice & node)
override

Definition at line 1120 of file ShapeValidator.cc.

1121{
1122 const auto &operands = _graph.operands();
1123 const auto output_index{node.getOutputs().at(0)};
1124 const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
1125
1126 if (operands.at(output_index).info().isDynamic())
1127 return;
1128
1129 OP_REQUIRES(operands.at(input_index).shape().rank() <= 5);
1130}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::StridedSlice::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [44/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Tile & node)
override

Definition at line 1132 of file ShapeValidator.cc.

1133{
1134 const auto &operands = _graph.operands();
1135 const auto output_index{node.getOutputs().at(0)};
1136 if (operands.at(output_index).info().isDynamic())
1137 return;
1138
1139 const auto input_index{node.getInputs().at(0)};
1140 const auto multiple_index{node.getInputs().at(1)};
1141
1142 OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1);
1143 OP_REQUIRES(operands.at(multiple_index).shape().dim(0) ==
1144 operands.at(input_index).shape().rank());
1145 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
1146}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [45/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Transpose & node)
override

Definition at line 1148 of file ShapeValidator.cc.

1149{
1150 const auto &operands = _graph.operands();
1151 const auto output_index{node.getOutputs().at(0)};
1152 if (operands.at(output_index).info().isDynamic())
1153 return;
1154
1155 const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
1156 const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
1157
1158 const auto &output_shape = operands.at(output_index).shape();
1159 const auto &input_shape = operands.at(input_index).shape();
1160
1161 OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 ||
1162 input_shape.rank() ==
1163 static_cast<int>(operands.at(perm_index).shape().num_elements()));
1164 OP_REQUIRES(input_shape.rank() == output_shape.rank());
1165}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Transpose::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Transpose::PERMUTATION.

◆ visit() [46/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::TransposeConv & node)
override

Definition at line 1167 of file ShapeValidator.cc.

1168{
1169 // shape check
1170 const auto &operands = _graph.operands();
1171 const auto ofm_index{node.getOutputs().at(0)};
1172
1173 if (operands.at(ofm_index).info().isDynamic())
1174 return;
1175
1176 const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
1177 const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
1178
1179 // Only 4D tensors are supported
1180 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
1181 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank());
1182 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank());
1183
1184 const auto ofm_shape = operands.at(ofm_index).shape().asFeature();
1185 const auto ifm_shape = operands.at(ifm_index).shape().asFeature();
1186 // The kernel has only IHWO layout on frontend
1187 // So ker_shape is treated here below
1188 // I -> N
1189 // H -> H
1190 // W -> W
1191 // O -> C
1192 const auto ker_shape = operands.at(ker_index).shape().asFeature();
1193
1194 OP_REQUIRES(ifm_shape.N == ofm_shape.N);
1195 OP_REQUIRES(ifm_shape.C == ker_shape.C);
1196 OP_REQUIRES(ker_shape.N == ofm_shape.C);
1197}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::TransposeConv::INPUT, onert::ir::operation::TransposeConv::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [47/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Unpack & node)
override

Definition at line 1199 of file ShapeValidator.cc.

1200{
1201 const auto &operands = _graph.operands();
1202 const auto axis{node.param().axis};
1203 const auto output_index{node.getInputs().at(0)};
1204 if (operands.at(output_index).info().isDynamic())
1205 return;
1206
1207 const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
1208
1209 const auto &input_shape = operands.at(input_index).shape();
1210 const auto input_rank = static_cast<int32_t>(input_shape.rank());
1211
1212 OP_REQUIRES(axis >= -input_rank && axis < input_rank);
1213}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Unpack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::operation::Unpack::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Unpack::param().

◆ visit() [48/48]

void onert::compiler::ShapeValidator::visit ( const ir::operation::While & node)
override

Definition at line 1215 of file ShapeValidator.cc.

1216{
1217 // This validator does not check shape. So checking isDynamic() is skipped.
1218 // TODO Add to validate with subgraphs
1219}

The documentation for this class was generated from the following files: