ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::compiler::ShapeValidator Class Reference

#include <ShapeValidator.h>

Collaboration diagram for onert::compiler::ShapeValidator:

Public Member Functions

 ShapeValidator (void)=delete
 
 ShapeValidator (const ir::Graph &graph)
 
 ShapeValidator (const ShapeValidator &)=delete
 
 ShapeValidator (ShapeValidator &&)=delete
 
 ~ShapeValidator ()=default
 
ShapeValidator & operator= (const ShapeValidator &)=delete
 
ShapeValidator & operator= (ShapeValidator &&)=delete
 
void operator() ()
 
void visit (const ir::operation::BatchMatMul &node) override
 
void visit (const ir::operation::BatchToSpaceND &node) override
 
void visit (const ir::operation::BCQFullyConnected &node) override
 
void visit (const ir::operation::BCQGather &node) override
 
void visit (const ir::operation::BroadcastTo &node) override
 
void visit (const ir::operation::Comparison &node) override
 
void visit (const ir::operation::Conv2D &node) override
 
void visit (const ir::operation::DepthToSpace &node) override
 
void visit (const ir::operation::DepthwiseConv2D &node) override
 
void visit (const ir::operation::DynamicUpdateSlice &node) override
 
void visit (const ir::operation::ElementwiseActivation &node) override
 
void visit (const ir::operation::ElementwiseBinary &node) override
 
void visit (const ir::operation::ElementwiseUnary &node) override
 
void visit (const ir::operation::EmbeddingLookup &node) override
 
void visit (const ir::operation::ExpandDims &node) override
 
void visit (const ir::operation::FullyConnected &node) override
 
void visit (const ir::operation::Gather &node) override
 
void visit (const ir::operation::HashtableLookup &node) override
 
void visit (const ir::operation::If &node) override
 
void visit (const ir::operation::InstanceNorm &node) override
 
void visit (const ir::operation::L2Normalization &node) override
 
void visit (const ir::operation::LogSoftmax &node) override
 
void visit (const ir::operation::LSTM &node) override
 
void visit (const ir::operation::Pack &node) override
 
void visit (const ir::operation::Pad &node) override
 
void visit (const ir::operation::Permute &node) override
 
void visit (const ir::operation::Pool2D &node) override
 
void visit (const ir::operation::Range &node) override
 
void visit (const ir::operation::Reduce &node) override
 
void visit (const ir::operation::ResizeBilinear &node) override
 
void visit (const ir::operation::Reverse &node) override
 
void visit (const ir::operation::RmsNorm &node) override
 
void visit (const ir::operation::RNN &node) override
 
void visit (const ir::operation::RoPE &node) override
 
void visit (const ir::operation::Select &node) override
 
void visit (const ir::operation::Shape &node) override
 
void visit (const ir::operation::Softmax &node) override
 
void visit (const ir::operation::SpaceToBatchND &node) override
 
void visit (const ir::operation::SpaceToDepth &node) override
 
void visit (const ir::operation::Split &node) override
 
void visit (const ir::operation::SquaredDifference &node) override
 
void visit (const ir::operation::StridedSlice &node) override
 
void visit (const ir::operation::Tile &node) override
 
void visit (const ir::operation::Transpose &node) override
 
void visit (const ir::operation::TransposeConv &node) override
 
void visit (const ir::operation::Unpack &node) override
 
void visit (const ir::operation::While &node) override
 
- Public Member Functions inherited from onert::ir::OperationVisitor
virtual ~OperationVisitor ()=default
 

Detailed Description

Definition at line 32 of file ShapeValidator.h.

Constructor & Destructor Documentation

◆ ShapeValidator() [1/4]

onert::compiler::ShapeValidator::ShapeValidator ( void  )
delete

◆ ShapeValidator() [2/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ir::Graph & graph)

Definition at line 35 of file ShapeValidator.cc.

35: _graph{graph} {}

◆ ShapeValidator() [3/4]

onert::compiler::ShapeValidator::ShapeValidator ( const ShapeValidator & )
delete

◆ ShapeValidator() [4/4]

onert::compiler::ShapeValidator::ShapeValidator ( ShapeValidator &&  )
delete

◆ ~ShapeValidator()

onert::compiler::ShapeValidator::~ShapeValidator ( )
default

Member Function Documentation

◆ operator()()

void onert::compiler::ShapeValidator::operator() ( )

Definition at line 50 of file ShapeValidator.cc.

51{
52 _graph.operations().iterate(
53 [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); });
54}
const Operations & operations() const override
Definition Graph.h:105
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
::onert::util::Index< uint32_t, OperationIndexTag > OperationIndex
Definition Index.h:30

References onert::ir::IOperation::accept(), onert::util::ObjectManager< Index, Object >::iterate(), and onert::ir::Graph::operations().

◆ operator=() [1/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( const ShapeValidator & )
delete

◆ operator=() [2/2]

ShapeValidator & onert::compiler::ShapeValidator::operator= ( ShapeValidator &&  )
delete

◆ visit() [1/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchMatMul & node)
override

Definition at line 56 of file ShapeValidator.cc.

57{
58 const auto &operands = _graph.operands();
59 const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
60 const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
61 const auto out_index{node.getOutputs().at(0)};
62
63 if (operands.at(out_index).info().isDynamic())
64 return;
65
66 OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4);
67 OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4);
68 OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2);
69 OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2);
70}
#define OP_REQUIRES(EXP)
const Operands & operands() const override
Definition Graph.h:103

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchMatMul::LHS, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::BatchMatMul::RHS.

◆ visit() [2/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BatchToSpaceND & node)
override

Definition at line 72 of file ShapeValidator.cc.

73{
74 const auto &operands = _graph.operands();
75 const auto ofm_index{node.getOutputs().at(0)};
76 if (operands.at(ofm_index).info().isDynamic())
77 return;
78
79 const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
80 const auto block_size_index{
81 node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
82
83 const auto input_shape = operands.at(ifm_index).shape().asFeature();
84 const auto output_shape = operands.at(ofm_index).shape().asFeature();
85
86 // All requirement as per NNAPI specification.
87 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
88 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
89 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
90
91 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
92
93 if (node.getInputs().size() != 2)
94 {
95 const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
96 OP_REQUIRES(operands.at(crops_index).shape().rank() == 2);
97 OP_REQUIRES(operands.at(crops_index).shape().dim(0) ==
98 (operands.at(ifm_index).shape().rank() - 2));
99 OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2);
100 }
101
102 OP_REQUIRES(input_shape.C == output_shape.C);
103}
const Object & at(const Index &index) const
Get the object that is associated with the given index.
const luci_interpreter::RuntimeShape output_shape

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BatchToSpaceND::BLOCK_SIZE, onert::ir::operation::BatchToSpaceND::CROPS_DATA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchToSpaceND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::OperandIndexSequence::size().

◆ visit() [3/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQFullyConnected & node)
override

Definition at line 105 of file ShapeValidator.cc.

106{
107 const auto &operands = _graph.operands();
108 const auto ofm_index{node.getOutputs().at(0)};
109 if (operands.at(ofm_index).info().isDynamic())
110 return;
111
112 const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
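113 const auto weight_scales_index{
114 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
115 const auto weight_binary_index{
116 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
117 const auto weight_cluster_index{
118 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};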
119 const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
120
121 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
122 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
123 OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
124 OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
125 OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
126
127 OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
128
129 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
130 OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
131
132 // more shape validation will be done inside kernel.
133
134 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
135}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::BCQFullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQFullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::BCQFullyConnected::WEIGHTS_BINARY, onert::ir::operation::BCQFullyConnected::WEIGHTS_CLUSTERS, and onert::ir::operation::BCQFullyConnected::WEIGHTS_SCALES.

◆ visit() [4/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BCQGather & node)
override

Definition at line 137 of file ShapeValidator.cc.

138{
139 const auto &operands = _graph.operands();
140 const auto ofm_index{node.getOutputs().at(0)};
141 if (operands.at(ofm_index).info().isDynamic())
142 return;
143
144 const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
145 const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
146 const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
147 const auto input_clusters_index{
148 node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
149
150 OP_REQUIRES(operands.at(indices_index).shape().rank() <=
151 2); // TODO : support rank up to 4 or more
152 OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2);
153 OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1);
154 OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2);
155
156 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0);
157 OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2);
158
159 // more shape validation will be done inside kernel.
160}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BCQGather::INDICES, onert::ir::operation::BCQGather::INPUT_BINARY, onert::ir::operation::BCQGather::INPUT_CLUSTERS, onert::ir::operation::BCQGather::INPUT_SCALES, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [5/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::BroadcastTo & node)
override

Definition at line 162 of file ShapeValidator.cc.

163{
164 const auto &operands = _graph.operands();
165 const auto output_index{node.getOutputs().at(0)};
166 if (operands.at(output_index).info().isDynamic())
167 return;
168
169 const auto input_index{node.getInputs().at(ir::operation::BroadcastTo::Input::INPUT)};
170 const auto shape_index{node.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
171 const auto &input_shape = operands.at(input_index).shape();
172 const auto &output_shape_vec = operands.at(shape_index).asVector<int32_t>();
173 int input_num_dims = input_shape.rank();
174 int output_num_dims = output_shape_vec.size();
175 OP_REQUIRES(input_num_dims <= output_num_dims);
176
177 int extending_dims = output_num_dims - input_num_dims;
178 for (int idx = 0; idx < input_num_dims; ++idx)
179 {
180 OP_REQUIRES(input_shape.dim(idx) == 1 ||
181 input_shape.dim(idx) == output_shape_vec.at(extending_dims + idx));
182 }
183}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BroadcastTo::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::BroadcastTo::SHAPE.

◆ visit() [6/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Comparison & node)
override

Definition at line 185 of file ShapeValidator.cc.

186{
187 // TODO Shape validation of comparison
188}

◆ visit() [7/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Conv2D & node)
override

Definition at line 190 of file ShapeValidator.cc.

191{
192 const auto &operands = _graph.operands();
193 const auto ofm_index{node.getOutputs().at(0)};
194 if (operands.at(ofm_index).info().isDynamic())
195 return;
196
197 const auto ifm_index{node.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
198 const auto ker_index{node.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
199 const auto bias_index{node.getInputs().at(ir::operation::Conv2D::Input::BIAS)};
200
201 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
202 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
203 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
204 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
205}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Conv2D::INPUT, onert::ir::operation::Conv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [8/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthToSpace & node)
override

Definition at line 207 of file ShapeValidator.cc.

208{
209 const auto &operands = _graph.operands();
210 int32_t block_size = node.param().block_size;
211
212 // shape check
213 const auto output_index{node.getOutputs().at(0)};
214 if (operands.at(output_index).info().isDynamic())
215 return;
216
217 const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
218
219 const auto output_shape = operands.at(output_index).shape().asFeature();
220 const auto input_shape = operands.at(input_index).shape().asFeature();
221
222 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
223 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
224
225 {
226 OP_REQUIRES(output_shape.N == input_shape.N);
227 OP_REQUIRES(output_shape.H == input_shape.H * block_size);
228 OP_REQUIRES(output_shape.W == input_shape.W * block_size);
229 OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
230 OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
231 }
232}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthToSpace::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthToSpace::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::DepthToSpace::param().

◆ visit() [9/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DepthwiseConv2D & node)
override

Definition at line 234 of file ShapeValidator.cc.

235{
236 const auto &operands = _graph.operands();
237 const auto ofm_index{node.getOutputs().at(0)};
238 if (operands.at(ofm_index).info().isDynamic())
239 return;
240
241 const auto ifm_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)};
242 const auto ker_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)};
243 const auto bias_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::BIAS)};
244
245 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
246 OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
247 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
248 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
249}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthwiseConv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthwiseConv2D::INPUT, onert::ir::operation::DepthwiseConv2D::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [10/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::DynamicUpdateSlice & node)
override

Definition at line 251 of file ShapeValidator.cc.

252{
253 const auto &operands = _graph.operands();
254 const auto output_index{node.getOutputs().at(0)};
255 if (operands.at(output_index).info().isDynamic())
256 return;
257
258 const auto operand_index{node.getInputs().at(ir::operation::DynamicUpdateSlice::Input::OPERAND)};
259 const auto update_index{node.getInputs().at(ir::operation::DynamicUpdateSlice::Input::UPDATE)};
260 const auto indices_index{node.getInputs().at(ir::operation::DynamicUpdateSlice::Input::INDICES)};
261
262 OP_REQUIRES(operands.at(indices_index).shape().rank() == 1);
263 OP_REQUIRES(operands.at(indices_index).shape().dim(0) ==
264 operands.at(operand_index).shape().rank());
265 OP_REQUIRES(operands.at(operand_index).shape().rank() ==
266 operands.at(update_index).shape().rank());
267 for (int i = 0; i < operands.at(operand_index).shape().rank(); i++)
268 {
269 OP_REQUIRES(operands.at(operand_index).shape().dim(i) >=
270 operands.at(update_index).shape().dim(i));
271 }
272 OP_REQUIRES(operands.at(operand_index).shape() == operands.at(output_index).shape());
273}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DynamicUpdateSlice::INDICES, OP_REQUIRES, onert::ir::operation::DynamicUpdateSlice::OPERAND, onert::ir::Graph::operands(), and onert::ir::operation::DynamicUpdateSlice::UPDATE.

◆ visit() [11/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseActivation & node)
override

Definition at line 275 of file ShapeValidator.cc.

275{ checkUnaryOp(node); }

◆ visit() [12/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseBinary & node)
override

Definition at line 277 of file ShapeValidator.cc.

278{
279 // TODO Shape validation of ElementwiseBinary
280}

◆ visit() [13/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ElementwiseUnary & node)
override

Definition at line 282 of file ShapeValidator.cc.

283{
284 const auto &operands = _graph.operands();
285 const auto output_index{node.getOutputs().at(0)};
286 const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
287
288 if (operands.at(output_index).info().isDynamic())
289 return;
290
291 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
292}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseUnary::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [14/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::EmbeddingLookup & node)
override

Definition at line 294 of file ShapeValidator.cc.

295{
296 const auto &operands = _graph.operands();
297 const auto output_index{node.getOutputs().at(0)};
298 const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
299 const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
300
301 const auto &output_obj = operands.at(output_index);
302 const auto &lookups_obj = operands.at(lookups_index);
303 const auto &values_obj = operands.at(values_index);
304
305 // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
306 // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
307 {
308 if (operands.at(output_index).info().isDynamic())
309 return;
310
311 const auto &output_shape = output_obj.shape();
312 const auto &lookups_shape = lookups_obj.shape();
313 const auto &values_shape = values_obj.shape();
314
315 OP_REQUIRES(lookups_shape.rank() == 1);
316 OP_REQUIRES(values_shape.rank() >= 2);
317
318 // output should be a n-D tensor with the same rank and shape as the values tensor, except for
319 // the first dimension which has the same size as lookups' only dimension.
320 OP_REQUIRES(output_shape.rank() == values_shape.rank());
321 OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
322 for (int n = 1; n < output_shape.rank(); ++n)
323 {
324 OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
325 }
326 }
327}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::EmbeddingLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::EmbeddingLookup::VALUES.

◆ visit() [15/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ExpandDims & node)
override

Definition at line 329 of file ShapeValidator.cc.

330{
331 const auto &operands = _graph.operands();
332 const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
333
334 if (operands.at(axis_index).info().isDynamic())
335 return;
336 OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1);
337}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ExpandDims::AXIS, onert::ir::Operation::getInputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [16/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::FullyConnected & node)
override

Definition at line 339 of file ShapeValidator.cc.

340{
341 const auto &operands = _graph.operands();
342 const auto ofm_index{node.getOutputs().at(0)};
343 if (operands.at(ofm_index).info().isDynamic())
344 return;
345
346 const auto ifm_index{node.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
347 const auto ker_index{node.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
348 const auto bias_index{node.getInputs().at(ir::operation::FullyConnected::Input::BIAS)};
349
350 OP_REQUIRES(operands.at(ifm_index).shape().rank() >= 2);
351 OP_REQUIRES(operands.at(ker_index).shape().rank() == 2);
352 OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
353}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::FullyConnected::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::FullyConnected::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::FullyConnected::WEIGHT.

◆ visit() [17/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Gather & node)
override

Definition at line 355 of file ShapeValidator.cc.

356{
357 const auto &operands = _graph.operands();
358 const auto ofm_index{node.getOutputs().at(0)};
359 if (operands.at(ofm_index).info().isDynamic())
360 return;
361
362 const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
363 const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
364
365 const auto &ifm_shape = operands.at(ifm_index).shape();
366 const auto &indices_shape = operands.at(indices_index).shape();
367 const auto &ofm_shape = operands.at(ofm_index).shape();
368
369 // Since gather implementation is general enough, we do not restrict max rank
370 OP_REQUIRES(ifm_shape.rank() + indices_shape.rank() - 1 == ofm_shape.rank());
371}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Gather::INDICES, onert::ir::operation::Gather::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [18/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::HashtableLookup & node)
override

Definition at line 373 of file ShapeValidator.cc.

374{
375 const auto &operands = _graph.operands();
376 const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
377 const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
378 const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
379 const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
380
381 const auto &output_obj = operands.at(output_index);
382 const auto &lookups_obj = operands.at(lookups_index);
383 const auto &keys_obj = operands.at(keys_index);
384 const auto &values_obj = operands.at(values_index);
385
386 if (operands.at(output_index).info().isDynamic())
387 return;
388
389 const auto &output_shape = output_obj.shape();
390 const auto &lookups_shape = lookups_obj.shape();
391 const auto &keys_shape = keys_obj.shape();
392 const auto &values_shape = values_obj.shape();
393
394 OP_REQUIRES(values_shape.rank() == output_shape.rank());
395 OP_REQUIRES(lookups_shape.rank() == 1);
396 OP_REQUIRES(keys_shape.rank() == 1);
397 OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
398 OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
399}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::HashtableLookup::KEYS, onert::ir::operation::HashtableLookup::LOOKUPS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::HashtableLookup::OUTPUT, output_shape, and onert::ir::operation::HashtableLookup::VALUES.

◆ visit() [19/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::If & node)
override

Definition at line 401 of file ShapeValidator.cc.

402{
403 // TODO Add to validate with subgraphs
404}

◆ visit() [20/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::InstanceNorm & node)
override

Definition at line 406 of file ShapeValidator.cc.

407{
408 const auto &operands = _graph.operands();
409 const auto ofm_index{node.getOutputs().at(0)};
410 if (operands.at(ofm_index).info().isDynamic())
411 return;
412
413 const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
414 const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
415 const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
416
417 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
418 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
419 OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1);
420 OP_REQUIRES(operands.at(beta_index).shape().rank() == 1);
421}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::InstanceNorm::BETA, onert::ir::operation::InstanceNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::InstanceNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [21/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::L2Normalization & node)
override

Definition at line 423 of file ShapeValidator.cc.

424{
425 const auto &operands = _graph.operands();
426 const auto ofm_index{node.getOutputs().at(0)};
427 if (operands.at(ofm_index).info().isDynamic())
428 return;
429
430 const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
431
432 auto ifm_shape = operands.at(ifm_index).shape();
433 auto ofm_shape = operands.at(ofm_index).shape();
434
435 OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
436
437 for (auto i = 0; i < ifm_shape.rank(); i++)
438 {
439 OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
440 }
441}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::L2Normalization::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [22/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LogSoftmax & node)
override

Definition at line 443 of file ShapeValidator.cc.

444{
445 const auto &operands = _graph.operands();
446 const auto output_index{node.getOutputs().at(0)};
447 if (operands.at(output_index).info().isDynamic())
448 return;
449
450 const auto input_index{node.getInputs().at(0)};
451
452 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
453}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [23/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::LSTM node)
override

Definition at line 455 of file ShapeValidator.cc.

456{
457 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
458 // TODO Support dynamic rnn
459 const auto &operands = _graph.operands();
460 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
461 if (operands.at(output_index).info().isDynamic())
462 return;
463
464 const auto scratch_buffer_index{
465 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
466 const auto output_state_out_index{
467 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
468 const auto cell_state_out_index{
469 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
470
471 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
472 const auto input_to_input_weights_index{
473 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
474 const auto input_to_forget_weights_index{
476 const auto input_to_cell_weights_index{
478 const auto input_to_output_weights_index{
480 const auto recurrent_to_input_weights_index{
481 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
482 const auto recurrent_to_forget_weights_index{
484 const auto recurrent_to_cell_weights_index{
486 const auto recurrent_to_output_weights_index{
488 const auto cell_to_input_weights_index{
489 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
490 const auto cell_to_forget_weights_index{
491 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
492 const auto cell_to_output_weights_index{
493 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
494 const auto input_gate_bias_index{
495 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
496 const auto forget_gate_bias_index{
498 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
499 const auto output_gate_bias_index{
501 const auto projection_weights_index{
502 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
503 const auto projection_bias_index{
504 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
505 const auto output_state_in_index{
507 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
508
509 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
510 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
511 {
512 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
513 operands.at(output_index).shape().dim(i));
514 }
515 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
516 operands.at(output_index).shape().rank() == 3) &&
517 (operands.at(input_index).shape().rank() == 2 ||
518 operands.at(input_index).shape().rank() == 3) &&
519 (!operands.exist(input_to_input_weights_index) ||
520 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
521 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
522 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
523 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
524 (!operands.exist(recurrent_to_input_weights_index) ||
525 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
526 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
527 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
528 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
529 (!operands.exist(projection_weights_index) ||
530 operands.at(projection_weights_index).shape().rank() == 2) &&
531 operands.at(output_state_in_index).shape().rank() == 2 &&
532 operands.at(cell_state_in_index).shape().rank() == 2);
533
534 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
535 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
536 (!operands.exist(cell_to_forget_weights_index) ||
537 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
538 (!operands.exist(cell_to_output_weights_index) ||
539 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
540 (!operands.exist(input_gate_bias_index) ||
541 operands.at(input_gate_bias_index).shape().rank() == 1) &&
542 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
543 operands.at(cell_bias_index).shape().rank() == 1 &&
544 operands.at(output_gate_bias_index).shape().rank() == 1 &&
545 (!operands.exist(projection_bias_index) ||
546 operands.at(projection_bias_index).shape().rank() == 1));
547
548 // CIFG assertion
549 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
550 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
551 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
552 (!operands.exist(recurrent_to_input_weights_index) ||
553 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
554 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
555 (!operands.exist(input_gate_bias_index) ||
556 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
557 (!operands.exist(cell_to_input_weights_index) ||
558 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
559 ((operands.exist(input_to_input_weights_index) &&
560 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
561 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
562 (operands.exist(recurrent_to_input_weights_index) &&
563 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
564 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
565 (operands.exist(input_gate_bias_index) &&
566 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
567
568 // Peephole assertion
569 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
570 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
571 (!operands.exist(cell_to_output_weights_index) ||
572 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
573 ((operands.exist(cell_to_forget_weights_index) &&
574 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
575 (operands.exist(cell_to_output_weights_index) &&
576 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
577
578 bool has_input_to_input_weights =
579 operands.exist(input_to_input_weights_index) &&
580 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
581 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
582 bool has_recurrent_to_input_weights =
583 operands.exist(recurrent_to_input_weights_index) &&
584 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
585 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
586 bool has_input_gate_bias =
587 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
588 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
589 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
590 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
591 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
592 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
593 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
594 bool has_projection_weights = operands.exist(projection_weights_index) &&
595 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
596 operands.at(projection_weights_index).shape().dim(1) != 0);
597 bool has_projection_bias =
598 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
599
600 // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
601 // true: no CIFG
602 // false: CIFG
603 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
604
605 // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
606 // true: peephole
607 // false: no peephole
608 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
609
610 // NOTE The projection weights may have data but the projection bias may not.
611 bool has_projection_param = has_projection_weights;
612
613 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
614 ? operands.at(input_index).shape().dim(1)
615 : operands.at(input_index).shape().dim(0);
616 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
617 batch_size == operands.at(cell_state_in_index).shape().dim(0));
618
619 const auto input_size =
620 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
621 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
622 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
623 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
624
625 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
626 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
627 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
628 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
629 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
630 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
631 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
632 num_units == operands.at(cell_bias_index).shape().dim(0) &&
633 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
634 num_units == operands.at(cell_state_in_index).shape().dim(1));
635
636 const auto output_size =
637 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
638 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
639 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
640 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
641 output_size == operands.at(output_state_in_index).shape().dim(1));
642
643 if (has_cifg_param)
644 {
645 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
647 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
648 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
649 ((operands.exist(cell_to_input_weights_index) &&
650 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
651 (!operands.exist(cell_to_input_weights_index) ||
652 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
653 num_units == operands.at(input_gate_bias_index).shape().dim(0));
654 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
655 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
656 has_input_gate_bias);
657 if (has_cell_to_input_weights)
658 {
659 // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
660 OP_REQUIRES(has_peephole_param);
661 }
662 if (operands.exist(scratch_buffer_index))
663 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
664 }
665 else
666 {
667 if (operands.exist(scratch_buffer_index))
668 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
669 }
670
671 if (has_peephole_param)
672 {
673 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
674 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
675 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
676 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
677 }
678
679 if (has_projection_param)
680 {
681 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
682 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
683 if (has_projection_bias)
684 {
685 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
686 }
687 }
688
689 if (operands.exist(scratch_buffer_index))
690 {
691 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
692 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
693 }
694
695 if (operands.exist(output_state_out_index))
696 {
697 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
698 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
699 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
700 }
701
702 if (operands.exist(cell_state_out_index))
703 {
704 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
705 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
706 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
707 }
708}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::LSTM::CELL_BIAS, onert::ir::operation::LSTM::CELL_STATE_IN, onert::ir::operation::LSTM::CELL_STATE_OUT, onert::ir::operation::LSTM::CELL_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::CELL_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::FORGET_GATE_BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::LSTM::INPUT, onert::ir::operation::LSTM::INPUT_GATE_BIAS, onert::ir::operation::LSTM::INPUT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::INPUT_TO_OUTPUT_WEIGHTS, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::LSTM::OUTPUT, onert::ir::operation::LSTM::OUTPUT_GATE_BIAS, onert::ir::operation::LSTM::OUTPUT_STATE_IN, onert::ir::operation::LSTM::OUTPUT_STATE_OUT, onert::ir::operation::LSTM::param(), onert::ir::operation::LSTM::PROJECTION_BIAS, onert::ir::operation::LSTM::PROJECTION_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_CELL_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_FORGET_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_INPUT_WEIGHTS, onert::ir::operation::LSTM::RECURRENT_TO_OUTPUT_WEIGHTS, onert::ir::operation::LSTM::SCRATCH_BUFFER, and onert::ir::operation::LSTM::Param::time_major.

◆ visit() [24/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pack node)
override

Definition at line 710 of file ShapeValidator.cc.

711{
712 const auto &operands = _graph.operands();
713 const auto axis{node.param().axis};
714 const auto output_index{node.getOutputs().at(0)};
715 if (operands.at(output_index).info().isDynamic())
716 return;
717
718 // shape check
719 const auto &output_shape = operands.at(output_index).shape();
720 const auto output_rank = static_cast<int32_t>(output_shape.rank());
721
722 const auto input1_index{node.getInputs().at(0)};
723 const auto &input_shape = operands.at(input1_index).shape();
724
725 OP_REQUIRES(axis >= -output_rank && axis < output_rank);
726 for (const auto &index : node.getInputs())
727 {
728 OP_REQUIRES(input_shape == operands.at(index).shape());
729 }
730}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Pack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Pack::param().

◆ visit() [25/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pad node)
override

Definition at line 732 of file ShapeValidator.cc.

733{
734 const auto &operands = _graph.operands();
735 const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
736 OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32);
737
738 const auto output_index{node.getInputs().at(0)};
739 if (operands.at(output_index).info().isDynamic())
740 return;
741
742 const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
743
744 const auto &pad_shape = operands.at(pad_index).shape();
745 const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank());
746
747 OP_REQUIRES(pad_shape.rank() == 2);
748 OP_REQUIRES(pad_shape.dim(0) == input_rank);
749 OP_REQUIRES(pad_shape.dim(1) == 2);
750 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
751}
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
uint32_t rank(void) const
Definition TensorShape.h:35
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings)

References onert::ir::OperandIndexSequence::at(), loco::TensorShape::dim(), onert::ir::Operation::getInputs(), onert::ir::operation::Pad::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::Pad::PAD, and loco::TensorShape::rank().

◆ visit() [26/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Permute node)
override

Definition at line 753 of file ShapeValidator.cc.

754{
755 const auto &operands = _graph.operands();
756 const auto output_index{node.getOutputs().at(0)};
757 if (operands.at(output_index).info().isDynamic())
758 return;
759
760 const auto input_index{node.getInputs().at(0)};
761
762 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
763}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [27/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Pool2D node)
override

Definition at line 765 of file ShapeValidator.cc.

766{
767 const auto &operands = _graph.operands();
768 const auto ofm_index{node.getOutputs().at(0)};
769 if (operands.at(ofm_index).info().isDynamic())
770 return;
771
772 const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
773
774 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
775}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pool2D::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [28/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Range node)
override

Definition at line 777 of file ShapeValidator.cc.

778{
779 const auto &operands = _graph.operands();
780 const auto output_index{node.getOutputs().at(0)};
781 const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
782 const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
783 const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
784
785 // Check for dimension constraints
786 if (operands.at(output_index).info().isDynamic())
787 return;
788
789 OP_REQUIRES(operands.at(start_index).shape().rank() == 0);
790 OP_REQUIRES(operands.at(limit_index).shape().rank() == 0);
791 OP_REQUIRES(operands.at(delta_index).shape().rank() == 0);
792}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Range::DELTA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Range::LIMIT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Range::START.

◆ visit() [29/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reduce node)
override

Definition at line 794 of file ShapeValidator.cc.

795{
796 const auto &operands = _graph.operands();
797 const auto output_index{node.getOutputs().at(0)};
798 if (operands.at(output_index).info().isDynamic())
799 return;
800
801 const auto &input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
802 const auto &input_shape = operands.at(input_index).shape();
803 const auto &output_shape = operands.at(output_index).shape();
804
805 OP_REQUIRES(input_shape.rank() <= 4);
806 OP_REQUIRES(output_shape.rank() <= input_shape.rank());
807
808 // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
809 // supports cases reducing height and width or reducing depth.
810 // TODO We have to support all cases of dimensions up to 4.
811 // For correct permuting, we have to set output's shape to be equal in dimension position of the
812 // input. But the positions of the same dimensions in the input and output may be set differently.
813 // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
814 // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
815 // extend it in 4 dimensions, it should be {1,1,3,5}.
816 // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
817 // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
818 // next operation is not desired.
819 if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
820 {
821 if (output_shape.rank() == 2)
822 {
823 // Reducing HW
824 OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
825 input_shape.dim(3) == output_shape.dim(1));
826 }
827 else if (output_shape.rank() == 3)
828 {
829 // Reducing C or
830 // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
832 (input_shape.dim(0) == output_shape.dim(0) && input_shape.dim(1) == output_shape.dim(1) &&
833 input_shape.dim(2) == output_shape.dim(2)) ||
834 (input_shape.dim(0) == output_shape.dim(0) &&
835 (input_shape.dim(1) == output_shape.dim(1) || input_shape.dim(2) == output_shape.dim(1)) &&
836 input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
837 }
838 }
839}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reduce::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and output_shape.

◆ visit() [30/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::ResizeBilinear node)
override

Definition at line 841 of file ShapeValidator.cc.

842{
843 const auto &operands = _graph.operands();
844 const auto output_index{node.getOutputs().at(0)};
845 const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
846
847 if (operands.at(output_index).info().isDynamic())
848 {
849 return;
850 }
851 OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
852 OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
853}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ResizeBilinear::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [31/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Reverse node)
override

Definition at line 855 of file ShapeValidator.cc.

856{
857 const auto &operands = _graph.operands();
858 const auto output_index{node.getOutputs().at(0)};
859 const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
860
861 if (operands.at(output_index).info().isDynamic())
862 return;
863 OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
864}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reverse::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [32/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RmsNorm node)
override

Definition at line 866 of file ShapeValidator.cc.

867{
868 const auto &operands = _graph.operands();
869 const auto ofm_index{node.getOutputs().at(0)};
870 if (operands.at(ofm_index).info().isDynamic())
871 return;
872
873 const auto ifm_index{node.getInputs().at(ir::operation::RmsNorm::Input::INPUT)};
874 const auto gamma_index{node.getInputs().at(ir::operation::RmsNorm::Input::GAMMA)};
875
876 const auto &ifm_shape = operands.at(ifm_index).shape();
877 const auto &ofm_shape = operands.at(ofm_index).shape();
878 const auto &gamma_shape = operands.at(gamma_index).shape();
879
880 OP_REQUIRES(ifm_shape.rank() == 3 || ifm_shape.rank() == 4);
881 OP_REQUIRES(ifm_shape == ofm_shape);
882 OP_REQUIRES(gamma_shape.rank() == 1);
883 OP_REQUIRES((gamma_shape.dim(0) == 1) ||
884 (gamma_shape.dim(0) == ifm_shape.dim(ifm_shape.rank() - 1)));
885}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RmsNorm::GAMMA, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RmsNorm::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [33/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RNN node)
override

Definition at line 887 of file ShapeValidator.cc.

888{
889 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
890 // TODO Support dynamic rnn
891 const auto &operands = _graph.operands();
892 const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
893 if (operands.at(output_index).info().isDynamic())
894 return;
895
896 const auto hidden_state_out_index{
898
899 const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
900 const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
901 const auto recurrent_weights_index{
903 const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
904 const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
905
906 const auto batch_size = operands.at(output_index).shape().dim(0);
907 const auto num_units = operands.at(output_index).shape().dim(1);
908
909 OP_REQUIRES(operands.at(output_index).shape().rank() == 2 &&
910 operands.at(hidden_state_out_index).shape().rank() == 2 &&
911 operands.at(input_index).shape().rank() == 2 &&
912 operands.at(weights_index).shape().rank() == 2 &&
913 operands.at(recurrent_weights_index).shape().rank() == 2 &&
914 operands.at(hidden_state_in_index).shape().rank() == 2);
915 OP_REQUIRES(operands.at(bias_index).shape().rank() == 1);
916
917 OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) &&
918 batch_size == operands.at(hidden_state_in_index).shape().dim(0) &&
919 batch_size == operands.at(hidden_state_out_index).shape().dim(0));
920 OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1));
921
922 OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) &&
923 num_units == operands.at(recurrent_weights_index).shape().dim(0) &&
924 num_units == operands.at(bias_index).shape().dim(0));
925 OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) &&
926 num_units == operands.at(recurrent_weights_index).shape().dim(1) &&
927 num_units == operands.at(hidden_state_in_index).shape().dim(1) &&
928 num_units == operands.at(hidden_state_out_index).shape().dim(1));
929}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RNN::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RNN::HIDDEN_STATE_IN, onert::ir::operation::RNN::HIDDEN_STATE_OUT, onert::ir::operation::RNN::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), onert::ir::operation::RNN::OUTPUT, onert::ir::operation::RNN::RECURRENT_WEIGHTS, and onert::ir::operation::RNN::WEIGHTS.

◆ visit() [34/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::RoPE node)
override

Definition at line 931 of file ShapeValidator.cc.

932{
933 const auto &operands = _graph.operands();
934 const auto ofm_index{node.getOutputs().at(0)};
935 if (operands.at(ofm_index).info().isDynamic())
936 return;
937
938 const auto ifm_index{node.getInputs().at(ir::operation::RoPE::Input::INPUT)};
939 const auto sin_table_index{node.getInputs().at(ir::operation::RoPE::Input::SIN_TABLE)};
940 const auto cos_table_index{node.getInputs().at(ir::operation::RoPE::Input::COS_TABLE)};
941
942 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
943 OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
944 OP_REQUIRES(operands.at(sin_table_index).shape().rank() == 4);
945 OP_REQUIRES(operands.at(cos_table_index).shape().rank() == 4);
946}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RoPE::COS_TABLE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RoPE::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::RoPE::SIN_TABLE.

◆ visit() [35/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Select node)
override

Definition at line 948 of file ShapeValidator.cc.

949{
950 // TODO Shape validation of select
951}

◆ visit() [36/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Shape node)
override

Definition at line 953 of file ShapeValidator.cc.

954{
955 const auto &operands = _graph.operands();
956 const auto output_index{node.getOutputs().at(0)};
957 if (operands.at(output_index).info().isDynamic())
958 return;
959
960 [[maybe_unused]] const auto input_index{node.getInputs().at(0)};
961 OP_REQUIRES(operands.at(output_index).shape().rank() == 1);
962}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [37/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Softmax node)
override

Definition at line 964 of file ShapeValidator.cc.

965{
966 const auto &operands = _graph.operands();
967 const auto output_index{node.getOutputs().at(0)};
968 if (operands.at(output_index).info().isDynamic())
969 return;
970
971 const auto input_index{node.getInputs().at(0)};
972
973 OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
974}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [38/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToBatchND node)
override

Definition at line 976 of file ShapeValidator.cc.

977{
978 const auto &operands = _graph.operands();
979 const auto ofm_index{node.getOutputs().at(0)};
980 if (operands.at(ofm_index).info().isDynamic())
981 return;
982
983 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
984 const auto block_size_index{
986 const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
987
988 const auto input_shape = operands.at(ifm_index).shape().asFeature();
989 const auto output_shape = operands.at(ofm_index).shape().asFeature();
990
991 // All requirement as per NNAPI specification.
992 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
993 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
994 OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
995 OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2);
996
997 OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
998 OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2);
999 OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2);
1000
1001 OP_REQUIRES(input_shape.C == output_shape.C);
1002}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToBatchND::BLOCK_SIZE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToBatchND::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToBatchND::PADDINGS.

◆ visit() [39/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SpaceToDepth &node)
override

Definition at line 1004 of file ShapeValidator.cc.

1005{
1006 const auto &operands = _graph.operands();
1007 const auto ofm_index{node.getOutputs().at(0)};
1008 if (operands.at(ofm_index).info().isDynamic())
1009 return;
1010
1011 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
1012
1013 const auto input_shape = operands.at(ifm_index).shape().asFeature();
1014 const auto output_shape = operands.at(ofm_index).shape().asFeature();
1015 const auto block_size = node.param().block_size;
1016
1017 // All assertions as per NNAPI specification.
1018 OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
1019 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
1020 OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
1021 OP_REQUIRES(input_shape.N == output_shape.N);
1022 OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
1023}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToDepth::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SpaceToDepth::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SpaceToDepth::param().

◆ visit() [40/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Split &node)
override

Definition at line 1025 of file ShapeValidator.cc.

1026{
1027 const auto &operands = _graph.operands();
1028 const auto output_index{node.getOutputs().at(0)};
1029 if (operands.at(output_index).info().isDynamic())
1030 return;
1031
1032 const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
1033 const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
1034
1035 const auto num_splits = node.param().num_splits;
1036 const auto input_rank = operands.at(input_index).shape().rank();
1037 auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base());
1038 axis = axis < 0 ? axis + input_rank : axis;
1039
1040 OP_REQUIRES(axis >= 0 && axis < input_rank);
1041 OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0);
1042}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Split::AXIS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Split::INPUT, onert::ir::operation::Split::Param::num_splits, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Split::param().

◆ visit() [41/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::SquaredDifference &node)
override

Definition at line 1044 of file ShapeValidator.cc.

1045{
1046 const auto &operands = _graph.operands();
1047 const auto output_index{node.getOutputs().at(0)};
1048 const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
1049 const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
1050
1051 // Check for dimension constraints
1052 if (operands.at(output_index).info().isDynamic())
1053 return;
1054
1055 auto output_shape = operands.at(output_index).shape();
1056 auto lhs_shape = operands.at(lhs_index).shape();
1057 auto rhs_shape = operands.at(rhs_index).shape();
1058 // Check for output rank
1059 OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
1060 auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
1061
1062 for (int idx = 1; idx <= min_rank; idx++)
1063 {
1064 int l_idx = lhs_shape.rank() - idx;
1065 int r_idx = rhs_shape.rank() - idx;
1066 int out_idx = output_shape.rank() - idx;
1067
1068 OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
1069
1070 auto l_dims = lhs_shape.dim(l_idx);
1071 auto r_dims = rhs_shape.dim(r_idx);
1072 auto out_dims = output_shape.dim(out_idx);
1073
1074 OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
1075 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
1076 }
1077 auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
1078 for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
1079 {
1080 int out_idx = output_shape.rank() - idx;
1081 int tmp_idx = tmp_shape.rank() - idx;
1082
1083 OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
1084 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
1085 }
1086}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SquaredDifference::LHS, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::SquaredDifference::RHS.

◆ visit() [42/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::StridedSlice &node)
override

Definition at line 1088 of file ShapeValidator.cc.

1089{
1090 const auto &operands = _graph.operands();
1091 const auto output_index{node.getOutputs().at(0)};
1092 const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
1093
1094 if (operands.at(output_index).info().isDynamic())
1095 return;
1096
1097 OP_REQUIRES(operands.at(input_index).shape().rank() <= 5);
1098}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::StridedSlice::INPUT, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [43/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Tile &node)
override

Definition at line 1100 of file ShapeValidator.cc.

1101{
1102 const auto &operands = _graph.operands();
1103 const auto output_index{node.getOutputs().at(0)};
1104 if (operands.at(output_index).info().isDynamic())
1105 return;
1106
1107 const auto input_index{node.getInputs().at(0)};
1108 const auto multiple_index{node.getInputs().at(1)};
1109
1110 OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1);
1111 OP_REQUIRES(operands.at(multiple_index).shape().dim(0) ==
1112 operands.at(input_index).shape().rank());
1113 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
1114}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [44/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Transpose &node)
override

Definition at line 1116 of file ShapeValidator.cc.

1117{
1118 const auto &operands = _graph.operands();
1119 const auto output_index{node.getOutputs().at(0)};
1120 if (operands.at(output_index).info().isDynamic())
1121 return;
1122
1123 const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
1124 const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
1125
1126 const auto &output_shape = operands.at(output_index).shape();
1127 const auto &input_shape = operands.at(input_index).shape();
1128
1129 OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 ||
1130 input_shape.rank() ==
1131 static_cast<int>(operands.at(perm_index).shape().num_elements()));
1132 OP_REQUIRES(input_shape.rank() == output_shape.rank());
1133}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Transpose::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), output_shape, and onert::ir::operation::Transpose::PERMUTATION.

◆ visit() [45/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::TransposeConv &node)
override

Definition at line 1135 of file ShapeValidator.cc.

1136{
1137 // shape check
1138 const auto &operands = _graph.operands();
1139 const auto ofm_index{node.getOutputs().at(0)};
1140
1141 if (operands.at(ofm_index).info().isDynamic())
1142 return;
1143
1144 const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
1145 const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
1146
1147 // Only 4D tensors are supported
1148 OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
1149 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank());
1150 OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank());
1151
1152 const auto ofm_shape = operands.at(ofm_index).shape().asFeature();
1153 const auto ifm_shape = operands.at(ifm_index).shape().asFeature();
1154 // The kernel has only IHWO layout on frontend
1155 // So ker_shape is treated here below
1156 // I -> N
1157 // H -> H
1158 // W -> W
1159 // O -> C
1160 const auto ker_shape = operands.at(ker_index).shape().asFeature();
1161
1162 OP_REQUIRES(ifm_shape.N == ofm_shape.N);
1163 OP_REQUIRES(ifm_shape.C == ker_shape.C);
1164 OP_REQUIRES(ker_shape.N == ofm_shape.C);
1165}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::TransposeConv::INPUT, onert::ir::operation::TransposeConv::KERNEL, OP_REQUIRES, and onert::ir::Graph::operands().

◆ visit() [46/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::Unpack &node)
override

Definition at line 1167 of file ShapeValidator.cc.

1168{
1169 const auto &operands = _graph.operands();
1170 const auto axis{node.param().axis};
1171 const auto output_index{node.getInputs().at(0)};
1172 if (operands.at(output_index).info().isDynamic())
1173 return;
1174
1175 const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
1176
1177 const auto &input_shape = operands.at(input_index).shape();
1178 const auto input_rank = static_cast<int32_t>(input_shape.rank());
1179
1180 OP_REQUIRES(axis >= -input_rank && axis < input_rank);
1181}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Unpack::Param::axis, onert::ir::Operation::getInputs(), onert::ir::operation::Unpack::INPUT, OP_REQUIRES, onert::ir::Graph::operands(), and onert::ir::operation::Unpack::param().

◆ visit() [47/47]

void onert::compiler::ShapeValidator::visit ( const ir::operation::While &node)
override

Definition at line 1183 of file ShapeValidator.cc.

1184{
1185 // This validator does not check shape. So checking isDynamic() is skipped.
1186 // TODO Add to validate with subgraphs
1187}

The documentation for this class was generated from the following files: