ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::ir::OperationValidator Class Reference

#include <OperationValidator.h>

Collaboration diagram for onert::ir::OperationValidator:

Public Member Functions

 OperationValidator (void)=delete
 
 OperationValidator (const Graph &graph)
 
void operator() ()
 
void visit (const operation::AddN &node) override
 
void visit (const operation::ArgMinMax &node) override
 
void visit (const operation::Attention &node) override
 
void visit (const operation::BatchMatMul &node) override
 
void visit (const operation::BatchToSpaceND &node) override
 
void visit (const operation::BinaryArithmetic &node) override
 
void visit (const operation::BroadcastTo &node) override
 
void visit (const operation::Comparison &node) override
 
void visit (const operation::Concat &node) override
 
void visit (const operation::Conv2D &node) override
 
void visit (const operation::DepthToSpace &node) override
 
void visit (const operation::DepthwiseConv2D &node) override
 
void visit (const operation::DetectionPostProcess &node) override
 
void visit (const operation::DynamicUpdateSlice &node) override
 
void visit (const operation::ElementwiseActivation &node) override
 
void visit (const operation::ElementwiseBinary &node) override
 
void visit (const operation::ElementwiseUnary &node) override
 
void visit (const operation::EmbeddingLookup &node) override
 
void visit (const operation::ExpandDims &node) override
 
void visit (const operation::Fill &node) override
 
void visit (const operation::Gather &node) override
 
void visit (const operation::HashtableLookup &node) override
 
void visit (const operation::Pack &node) override
 
void visit (const operation::Pad &node) override
 
void visit (const operation::Rank &node) override
 
void visit (const operation::ResizeBilinear &node) override
 
void visit (const operation::Reverse &node) override
 
void visit (const operation::RoPE &node) override
 
void visit (const operation::Select &node) override
 
void visit (const operation::Shape &node) override
 
void visit (const operation::Slice &node) override
 
void visit (const operation::Softmax &node) override
 
void visit (const operation::SpaceToBatchND &node) override
 
void visit (const operation::SpaceToDepth &node) override
 
void visit (const operation::Split &node) override
 
void visit (const operation::SquaredDifference &node) override
 
void visit (const operation::StatelessRandomUniform &node) override
 
void visit (const operation::StridedSlice &node) override
 
void visit (const operation::TopKV2 &node) override
 
void visit (const operation::Transpose &node) override
 
void visit (const operation::TransposeConv &node) override
 
void visit (const operation::Unpack &node) override
 
void visit (const operation::While &node) override
 
- Public Member Functions inherited from onert::ir::OperationVisitor
virtual ~OperationVisitor ()=default
 

Detailed Description

Definition at line 33 of file OperationValidator.h.

Constructor & Destructor Documentation

◆ OperationValidator() [1/2]

onert::ir::OperationValidator::OperationValidator ( void  )
delete

◆ OperationValidator() [2/2]

onert::ir::OperationValidator::OperationValidator ( const Graph &graph)

Definition at line 32 of file OperationValidator.cc.

33 : _operations{graph.operations()}, _operands{graph.operands()}
34{
35}

Member Function Documentation

◆ operator()()

void onert::ir::OperationValidator::operator() ( )

Definition at line 37 of file OperationValidator.cc.

38{
39 _operations.iterate([&](const OperationIndex &, const IOperation &node) { node.accept(*this); });
40}
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
::onert::util::Index< uint32_t, OperationIndexTag > OperationIndex
Definition Index.h:30

References onert::ir::IOperation::accept(), and onert::util::ObjectManager< Index, Object >::iterate().

◆ visit() [1/43]

void onert::ir::OperationValidator::visit ( const operation::AddN &node)
override

Definition at line 86 of file OperationValidator.cc.

87{
88 const auto output_index(node.getOutputs().at(0));
89
90 int size = node.getInputs().size();
91 for (int i = 0; i < size; i++)
92 {
93 const auto input_index(node.getInputs().at(i));
94 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32}));
95 OP_REQUIRES(isSameType(input_index, output_index));
96 }
97}
#define OP_REQUIRES(EXP)
int32_t size[5]
Definition Slice.cpp:35

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, onert::ir::OperandIndexSequence::size(), and size.

◆ visit() [2/43]

void onert::ir::OperationValidator::visit ( const operation::ArgMinMax &node)
override

Definition at line 99 of file OperationValidator.cc.

100{
101 const auto input_index(node.getInputs().at(operation::ArgMinMax::Input::INPUT));
102 const auto axis_index(node.getInputs().at(operation::ArgMinMax::Input::AXIS));
103 const auto output_index(node.getOutputs().at(0));
104 const auto output_type = node.param().output_type;
105
106 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8,
107 DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
108 OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
109 OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64}));
110 OP_REQUIRES(isValidType(output_index, output_type));
111}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ArgMinMax::AXIS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ArgMinMax::INPUT, OP_REQUIRES, onert::ir::operation::ArgMinMax::Param::output_type, and onert::ir::operation::ArgMinMax::param().

◆ visit() [3/43]

void onert::ir::OperationValidator::visit ( const operation::Attention &node)
override

Definition at line 113 of file OperationValidator.cc.

114{
115 const auto input_idx = node.getInputs().at(operation::Attention::Input::INPUT);
116 const auto &input_shape = _operands.at(input_idx).shape();
117
118 // Check if input's seq_len is 1
119 // Assuming input shape is [batch_size, seq_len, embedding_dim]
120 OP_REQUIRES(input_shape.rank() == 3);
121 OP_REQUIRES(input_shape.dim(1) == 1);
122
123 const auto cos_idx = node.getInputs().at(operation::Attention::Input::COS);
124 const auto sin_idx = node.getInputs().at(operation::Attention::Input::SIN);
125 const auto &cos_shape = _operands.at(cos_idx).shape();
126 const auto &sin_shape = _operands.at(sin_idx).shape();
127
128 // Check _cos and _sin shapes
129 // Assuming shape is [batch_size, seq_len, d_head]
130 // batch_size = 1, seq_len = 1
131 OP_REQUIRES(cos_shape.rank() == 3);
132 OP_REQUIRES(cos_shape.dim(0) == 1); // batch_size
133 OP_REQUIRES(cos_shape.dim(1) == 1); // seq_len
134
135 OP_REQUIRES(sin_shape.rank() == 3);
136 OP_REQUIRES(sin_shape.dim(0) == 1); // batch_size
137 OP_REQUIRES(sin_shape.dim(1) == 1); // seq_len
138
139 const auto pos_idx = node.getInputs().at(operation::Attention::Input::POS);
140 const auto &pos_shape = _operands.at(pos_idx).shape();
141
142 // Check pos tensor type and shape
143 OP_REQUIRES(isValidType(pos_idx, DataType::INT64));
144 OP_REQUIRES(pos_shape.rank() == 1);
145 OP_REQUIRES(pos_shape.dim(0) == 1);
146}
const Object & at(const Index &index) const
Get the object that is associated with the given index.

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::at(), onert::ir::operation::Attention::COS, onert::ir::Operation::getInputs(), onert::ir::operation::Attention::INPUT, OP_REQUIRES, onert::ir::operation::Attention::POS, and onert::ir::operation::Attention::SIN.

◆ visit() [4/43]

void onert::ir::OperationValidator::visit ( const operation::BatchMatMul &node)
override

Definition at line 148 of file OperationValidator.cc.

149{
150 const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
151 const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
152 const auto output_index(node.getOutputs().at(0));
153
154 // RHS can be constant, but LHS is not constant
155 // If one of inputs is constant, it must be RHS
156 // If two inputs are constant, BatchMatMul is optimized into constant by compiler
157 OP_REQUIRES(!isConstant(lhs_index));
158
159 // Allow hybrid quantization (lhs: float / rhs: qint8 / out: float)
160 OP_REQUIRES(isValidType(
161 lhs_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
162 OP_REQUIRES(isSameType(lhs_index, rhs_index) ||
163 ((operandType(lhs_index) == DataType::FLOAT32) &&
164 (operandType(rhs_index) == DataType::QUANT_INT8_ASYMM)));
165 OP_REQUIRES(isSameType(lhs_index, output_index));
166}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchMatMul::LHS, OP_REQUIRES, and onert::ir::operation::BatchMatMul::RHS.

◆ visit() [5/43]

void onert::ir::OperationValidator::visit ( const operation::BatchToSpaceND &node)
override

Definition at line 168 of file OperationValidator.cc.

169{
170 const auto input_index{node.getInputs().at(operation::BatchToSpaceND::Input::INPUT)};
171 const auto output_index{node.getOutputs().at(0)};
172
173 OP_REQUIRES(isSameType(input_index, output_index));
174}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BatchToSpaceND::INPUT, and OP_REQUIRES.

◆ visit() [6/43]

void onert::ir::OperationValidator::visit ( const operation::BinaryArithmetic &node)
override

Definition at line 176 of file OperationValidator.cc.

177{
178 const auto output_index{node.getOutputs().at(0)};
179 const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)};
180 const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)};
181
182 OP_REQUIRES(isSameType(lhs_index, rhs_index));
183 OP_REQUIRES(isSameType(lhs_index, output_index));
184}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BinaryArithmetic::LHS, OP_REQUIRES, and onert::ir::operation::BinaryArithmetic::RHS.

◆ visit() [7/43]

void onert::ir::OperationValidator::visit ( const operation::BroadcastTo &node)
override

Definition at line 186 of file OperationValidator.cc.

187{
188 const auto input_index(node.getInputs().at(operation::BroadcastTo::Input::INPUT));
189 const auto shape_index(node.getInputs().at(operation::BroadcastTo::Input::SHAPE));
190 const auto output_index(node.getOutputs().at(0));
191
192 OP_REQUIRES(isSameType(input_index, output_index));
193 OP_REQUIRES(isValidType(shape_index, {DataType::INT32, DataType::INT64}));
194}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::BroadcastTo::INPUT, OP_REQUIRES, and onert::ir::operation::BroadcastTo::SHAPE.

◆ visit() [8/43]

void onert::ir::OperationValidator::visit ( const operation::Comparison &node)
override

Definition at line 196 of file OperationValidator.cc.

197{
198 const auto output_index{node.getOutputs().at(0)};
199
200 const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)};
201 const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)};
202
203 OP_REQUIRES(isSameType(lhs_index, rhs_index));
204 OP_REQUIRES(isValidType(output_index, DataType::BOOL8));
205}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Comparison::INPUT0, onert::ir::operation::Comparison::INPUT1, and OP_REQUIRES.

◆ visit() [9/43]

void onert::ir::OperationValidator::visit ( const operation::Concat &node)
override

Definition at line 207 of file OperationValidator.cc.

208{
209 const auto output_index{node.getOutputs().at(0)};
210
211 for (auto &&input_index : node.getInputs())
212 {
213 OP_REQUIRES(isSameType(input_index, output_index));
214
215 // Int8 and Int16 quantization requires same scale and zero point
216 if (isValidType(output_index, DataType::QUANT_INT8_ASYMM) ||
217 isValidType(output_index, DataType::QUANT_INT16_SYMM))
218 {
219 OP_REQUIRES(isSameQuantParam(input_index, output_index));
220 }
221 }
222}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), and OP_REQUIRES.

◆ visit() [10/43]

void onert::ir::OperationValidator::visit ( const operation::Conv2D &node)
override

Definition at line 224 of file OperationValidator.cc.

225{
226 const auto input_index{node.getInputs().at(operation::Conv2D::Input::INPUT)};
227 const auto kernel_index{node.getInputs().at(operation::Conv2D::Input::KERNEL)};
228 const auto output_index{node.getOutputs().at(0)};
229
230 uint32_t stride_horizontal = node.param().stride.horizontal;
231 uint32_t stride_vertical = node.param().stride.vertical;
232 uint32_t dilation_width = node.param().dilation.width_factor;
233 uint32_t dilation_height = node.param().dilation.height_factor;
234
235 OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
236 OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
237 OP_REQUIRES(isSameType(input_index, output_index));
238
239 if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
240 {
241 for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
242 OP_REQUIRES(zeropoint == 0);
243 }
244}

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::Param::dilation, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::Dilation::height_factor, onert::ir::Stride::horizontal, onert::ir::operation::Conv2D::INPUT, onert::ir::operation::Conv2D::KERNEL, OP_REQUIRES, onert::ir::operation::Conv2D::param(), onert::ir::operation::Conv2D::Param::stride, onert::ir::Stride::vertical, and onert::ir::Dilation::width_factor.

◆ visit() [11/43]

void onert::ir::OperationValidator::visit ( const operation::DepthToSpace &node)
override

Definition at line 246 of file OperationValidator.cc.

247{
248 const auto input_index{node.getInputs().at(operation::DepthToSpace::Input::INPUT)};
249 const auto output_index{node.getOutputs().at(0)};
250
251 int32_t block_size = node.param().block_size;
252
253 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::INT64,
254 DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
255 OP_REQUIRES(isSameType(input_index, output_index));
256
257 OP_REQUIRES(block_size > 0);
258}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthToSpace::Param::block_size, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthToSpace::INPUT, OP_REQUIRES, and onert::ir::operation::DepthToSpace::param().

◆ visit() [12/43]

void onert::ir::OperationValidator::visit ( const operation::DepthwiseConv2D &node)
override

Definition at line 260 of file OperationValidator.cc.

261{
262 const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)};
263 const auto kernel_index{node.getInputs().at(operation::DepthwiseConv2D::Input::KERNEL)};
264 const auto output_index{node.getOutputs().at(0)};
265
266 uint32_t stride_horizontal = node.param().stride.horizontal;
267 uint32_t stride_vertical = node.param().stride.vertical;
268 uint32_t dilation_width = node.param().dilation.width_factor;
269 uint32_t dilation_height = node.param().dilation.height_factor;
270
271 OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
272 OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
273 OP_REQUIRES(isSameType(input_index, output_index));
274
275 if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
276 {
277 for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
278 OP_REQUIRES(zeropoint == 0);
279 }
280}

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::at(), onert::ir::operation::DepthwiseConv2D::Param::dilation, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::Dilation::height_factor, onert::ir::Stride::horizontal, onert::ir::operation::DepthwiseConv2D::INPUT, onert::ir::operation::DepthwiseConv2D::KERNEL, OP_REQUIRES, onert::ir::operation::DepthwiseConv2D::param(), onert::ir::operation::DepthwiseConv2D::Param::stride, onert::ir::Stride::vertical, and onert::ir::Dilation::width_factor.

◆ visit() [13/43]

void onert::ir::OperationValidator::visit ( const operation::DetectionPostProcess &node)
override

Definition at line 282 of file OperationValidator.cc.

283{
284 const auto &param = node.param();
285
286 // FIXME: number of classes should be 1 for now.
287 OP_REQUIRES(param.num_classes == 1);
288}

References OP_REQUIRES, and onert::ir::operation::DetectionPostProcess::param().

◆ visit() [14/43]

void onert::ir::OperationValidator::visit ( const operation::DynamicUpdateSlice &node)
override

Definition at line 290 of file OperationValidator.cc.

291{
292 const auto operand_index{node.getInputs().at(operation::DynamicUpdateSlice::Input::OPERAND)};
293 const auto update_index{node.getInputs().at(operation::DynamicUpdateSlice::Input::UPDATE)};
294 const auto indices_index{node.getInputs().at(operation::DynamicUpdateSlice::Input::INDICES)};
295 const auto output_index{node.getOutputs().at(0)};
296
297 OP_REQUIRES(isSameType(operand_index, update_index));
298 OP_REQUIRES(isSameType(operand_index, output_index));
299 OP_REQUIRES(operandType(indices_index) == DataType::INT32 ||
300 operandType(indices_index) == DataType::INT64);
301}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DynamicUpdateSlice::INDICES, OP_REQUIRES, onert::ir::operation::DynamicUpdateSlice::OPERAND, and onert::ir::operation::DynamicUpdateSlice::UPDATE.

◆ visit() [15/43]

void onert::ir::OperationValidator::visit ( const operation::ElementwiseActivation &node)
override

Definition at line 303 of file OperationValidator.cc.

304{
305 const auto output_index{node.getOutputs().at(0)};
306 const auto input_index{node.getInputs().at(0)};
307
308 // Check if I/O types match
309 OP_REQUIRES(isSameType(output_index, input_index));
310
311 switch (node.param().op_type)
312 {
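313 case operation::ElementwiseActivation::Type::ELU:
314 case operation::ElementwiseActivation::Type::GELU:
315 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
316 break;
317 case operation::ElementwiseActivation::Type::LEAKY_RELU:
318 OP_REQUIRES(
319 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
320 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_SYMM}));
321 break;
322 case operation::ElementwiseActivation::Type::LOGISTIC:
323 OP_REQUIRES(
324 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
325 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_SYMM}));
326 break;
327 case operation::ElementwiseActivation::Type::RELU:
328 OP_REQUIRES(isValidType(
329 input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
330 break;
331 case operation::ElementwiseActivation::Type::TANH:
332 OP_REQUIRES(
333 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
334 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_SYMM}));
335 break;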
336 }
337}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ElementwiseActivation::ELU, onert::ir::operation::ElementwiseActivation::GELU, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseActivation::LEAKY_RELU, onert::ir::operation::ElementwiseActivation::LOGISTIC, OP_REQUIRES, onert::ir::operation::ElementwiseActivation::Param::op_type, onert::ir::operation::ElementwiseActivation::param(), onert::ir::operation::ElementwiseActivation::RELU, and onert::ir::operation::ElementwiseActivation::TANH.

◆ visit() [16/43]

void onert::ir::OperationValidator::visit ( const operation::ElementwiseBinary &node)
override

◆ visit() [17/43]

void onert::ir::OperationValidator::visit ( const operation::ElementwiseUnary &node)
override

Definition at line 356 of file OperationValidator.cc.

357{
358 const auto output_index{node.getOutputs().at(0)};
359 const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)};
360
361 // Check if I/O types match
362 if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE)
363 {
364 // NNAPI allow QUANT_INT8_SYMM type input
365 OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM,
366 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_SYMM}));
367 OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
368 }
369 else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE)
370 {
371 OP_REQUIRES(isValidType(
372 input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
373 OP_REQUIRES(isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM,
374 DataType::QUANT_INT16_SYMM}));
375 }
376 else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR)
377 {
378 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
379 OP_REQUIRES(isSameType(output_index, input_index));
380 }
381 else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST)
382 {
383 OP_REQUIRES(isSameType(output_index, input_index));
384 }
385}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ElementwiseUnary::CAST, onert::ir::operation::ElementwiseUnary::DEQUANTIZE, onert::ir::operation::ElementwiseUnary::FLOOR, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseUnary::INPUT, OP_REQUIRES, onert::ir::operation::ElementwiseUnary::Param::op_type, onert::ir::operation::ElementwiseUnary::param(), and onert::ir::operation::ElementwiseUnary::QUANTIZE.

◆ visit() [18/43]

void onert::ir::OperationValidator::visit ( const operation::EmbeddingLookup &node)
override

Definition at line 387 of file OperationValidator.cc.

388{
389 const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)};
390 const auto values_index{node.getInputs().at(operation::EmbeddingLookup::Input::VALUES)};
391 const auto output_index{node.getOutputs().at(0)};
392
393 OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
394
395 // TFLite: Allow hybrid type - value table & output
396 // NNAPI: Require same value table and output type
397 OP_REQUIRES(
398 isSameType(values_index, output_index) ||
399 (isValidType(output_index, DataType::FLOAT32) &&
400 (isValidType(values_index, {DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT8_SYMM}))));
401}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::EmbeddingLookup::LOOKUPS, OP_REQUIRES, and onert::ir::operation::EmbeddingLookup::VALUES.

◆ visit() [19/43]

void onert::ir::OperationValidator::visit ( const operation::ExpandDims &node)
override

Definition at line 403 of file OperationValidator.cc.

404{
405 const auto output_index{node.getOutputs().at(0)};
406 const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)};
407 const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)};
408
409 OP_REQUIRES(isSameType(output_index, input_index));
410 OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
411}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::ExpandDims::AXIS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ExpandDims::INPUT, and OP_REQUIRES.

◆ visit() [20/43]

void onert::ir::OperationValidator::visit ( const operation::Fill &node)
override

Definition at line 413 of file OperationValidator.cc.

414{
415 const auto output_index{node.getOutputs().at(0)};
416 const auto input_index{node.getInputs().at(operation::Fill::Input::SHAPE)};
417 const auto value_index{node.getInputs().at(operation::Fill::Input::VALUE)};
418
419 OP_REQUIRES(isSameType(output_index, value_index));
420 OP_REQUIRES(isValidType(input_index, {DataType::INT32, DataType::INT64}));
421 OP_REQUIRES(isValidType(output_index,
422 {DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8}));
423}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, onert::ir::operation::Fill::SHAPE, and onert::ir::operation::Fill::VALUE.

◆ visit() [21/43]

void onert::ir::OperationValidator::visit ( const operation::Gather &node)
override

Definition at line 425 of file OperationValidator.cc.

426{
427 const auto output_index{node.getOutputs().at(0)};
428 const auto input_index{node.getInputs().at(operation::Gather::INPUT)};
429 const auto indices_index{node.getInputs().at(operation::Gather::INDICES)};
430
431 const auto input_type = operandType(input_index);
432 if (input_type == DataType::QUANT_GGML_Q4_0 || input_type == DataType::QUANT_GGML_Q8_0)
433 OP_REQUIRES(isValidType(output_index, {DataType::FLOAT32}));
434 else
435 OP_REQUIRES(isSameType(output_index, input_index));
436
437 OP_REQUIRES(isValidType(indices_index, {DataType::INT32, DataType::INT64}));
438}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Gather::INDICES, onert::ir::operation::Gather::INPUT, and OP_REQUIRES.

◆ visit() [22/43]

void onert::ir::OperationValidator::visit ( const operation::HashtableLookup &node)
override

Definition at line 440 of file OperationValidator.cc.

441{
442 const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
443 const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)};
444 const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)};
445
446 OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
447 OP_REQUIRES(isValidType(keys_index, DataType::INT32));
448 OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM));
449}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::HashtableLookup::HITS, onert::ir::operation::HashtableLookup::KEYS, onert::ir::operation::HashtableLookup::LOOKUPS, and OP_REQUIRES.

◆ visit() [23/43]

void onert::ir::OperationValidator::visit ( const operation::Pack &node)
override

Definition at line 451 of file OperationValidator.cc.

452{
453 const auto num{node.param().num};
454
455 OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
456}

References onert::ir::Operation::getInputs(), onert::ir::operation::Pack::Param::num, OP_REQUIRES, onert::ir::operation::Pack::param(), and onert::ir::OperandIndexSequence::size().

◆ visit() [24/43]

void onert::ir::OperationValidator::visit ( const operation::Pad &node)
override

Definition at line 458 of file OperationValidator.cc.

459{
460 const auto output_index{node.getOutputs().at(0)};
461 const auto input_index{node.getInputs().at(operation::Pad::Input::INPUT)};
462 const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)};
463 bool isQuantType =
464 isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM});
465 bool isPadV2 = node.getInputs().size() == 3 ? true : false;
466
467 OP_REQUIRES(isValidType(pad_index, DataType::INT32));
468 OP_REQUIRES(isSameType(input_index, output_index));
469
470 if (isQuantType)
471 OP_REQUIRES(isSameQuantParam(input_index, output_index));
472
473 if (isPadV2)
474 {
475 const auto value_index{node.getInputs().at(operation::Pad::Input::VALUE)};
476 const bool cond_same = isSameType(input_index, value_index);
477 const bool cond_same_quant = (!isQuantType || isSameQuantParam(input_index, value_index));
478 const auto input_t = operandType(input_index);
479 const auto value_t = operandType(value_index);
480 // NNAPI accepts this case. scale and zeroPoint are assumed to be the same as in input0.
481 const bool cond_quant8 =
482 ((input_t == DataType::QUANT_UINT8_ASYMM || input_t == DataType::QUANT_INT8_ASYMM) &&
483 value_t == DataType::INT32);
484 OP_REQUIRES((cond_same && cond_same_quant) || cond_quant8);
485 }
486}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pad::INPUT, OP_REQUIRES, onert::ir::operation::Pad::PAD, onert::ir::OperandIndexSequence::size(), and onert::ir::operation::Pad::VALUE.

◆ visit() [25/43]

void onert::ir::OperationValidator::visit ( const operation::Rank &node)
override

Definition at line 488 of file OperationValidator.cc.

489{
490 const auto output_index{node.getOutputs().at(0)};
491
492 OP_REQUIRES(isValidType(output_index, DataType::INT32));
493}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getOutputs(), and OP_REQUIRES.

◆ visit() [26/43]

void onert::ir::OperationValidator::visit ( const operation::ResizeBilinear &node)
override

Definition at line 495 of file OperationValidator.cc.

496{
497 auto align_corners = node.param().align_corners;
498 auto half_pixel_centers = node.param().half_pixel_centers;
499
500 OP_REQUIRES(!align_corners || !half_pixel_centers);
501}

References onert::ir::operation::ResizeBilinear::Param::align_corners, onert::ir::operation::ResizeBilinear::Param::half_pixel_centers, OP_REQUIRES, and onert::ir::operation::ResizeBilinear::param().

◆ visit() [27/43]

void onert::ir::OperationValidator::visit ( const operation::Reverse &node)
override

Definition at line 503 of file OperationValidator.cc.

504{
505 const auto output_index{node.getOutputs().at(0)};
506 const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)};
507 const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)};
508
509 OP_REQUIRES(isValidType(axis_index, DataType::INT32));
510 OP_REQUIRES(isSameType(output_index, input_index));
511}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Reverse::AXIS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reverse::INPUT, and OP_REQUIRES.

◆ visit() [28/43]

void onert::ir::OperationValidator::visit ( const operation::RoPE &node)
override

Definition at line 513 of file OperationValidator.cc.

514{
515 const auto input_index{node.getInputs().at(operation::RoPE::Input::INPUT)};
516 const auto sin_index{node.getInputs().at(operation::RoPE::Input::SIN_TABLE)};
517 const auto cos_index{node.getInputs().at(operation::RoPE::Input::COS_TABLE)};
518 const auto output_index{node.getOutputs().at(operation::RoPE::Output::OUTPUT)};
519
520 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
521 OP_REQUIRES(isValidType(sin_index, DataType::FLOAT32));
522 OP_REQUIRES(isValidType(cos_index, DataType::FLOAT32));
523 OP_REQUIRES(isSameType(input_index, output_index));
524}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::RoPE::COS_TABLE, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::RoPE::INPUT, OP_REQUIRES, onert::ir::operation::RoPE::OUTPUT, and onert::ir::operation::RoPE::SIN_TABLE.

◆ visit() [29/43]

void onert::ir::OperationValidator::visit ( const operation::Select &node)
override

Definition at line 526 of file OperationValidator.cc.

527{
528 const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)};
529 const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)};
530 const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)};
531
532 OP_REQUIRES(isValidType(condition_index, DataType::BOOL8));
533 OP_REQUIRES(isSameType(input_true_index, input_false_index));
534}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Select::CONDITION, onert::ir::Operation::getInputs(), onert::ir::operation::Select::INPUT_FALSE, onert::ir::operation::Select::INPUT_TRUE, and OP_REQUIRES.

◆ visit() [30/43]

void onert::ir::OperationValidator::visit (const operation::Shape &node)
override

Definition at line 536 of file OperationValidator.cc.

537{
538 const auto output_index{node.getOutputs().at(0)};
539
540 OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64}));
541}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getOutputs(), and OP_REQUIRES.

◆ visit() [31/43]

void onert::ir::OperationValidator::visit (const operation::Slice &node)
override

Definition at line 543 of file OperationValidator.cc.

544{
545 const auto begins_index{node.getInputs().at(operation::Slice::BEGINS)};
546 const auto sizes_index{node.getInputs().at(operation::Slice::SIZES)};
547
548 OP_REQUIRES(isValidType(begins_index, {DataType::INT32, DataType::INT64}));
549 OP_REQUIRES(isSameType(begins_index, sizes_index));
550}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::Slice::BEGINS, onert::ir::Operation::getInputs(), OP_REQUIRES, and onert::ir::operation::Slice::SIZES.

◆ visit() [32/43]

void onert::ir::OperationValidator::visit (const operation::Softmax &node)
override

Definition at line 552 of file OperationValidator.cc.

553{
554 const auto output_index{node.getOutputs().at(0)};
555 const auto input_index{node.getInputs().at(operation::Softmax::INPUT)};
556
557 OP_REQUIRES(isSameType(input_index, output_index));
558 OP_REQUIRES(isValidType(
559 output_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
560}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Softmax::INPUT, and OP_REQUIRES.

◆ visit() [33/43]

void onert::ir::OperationValidator::visit (const operation::SpaceToBatchND &node)
override

Definition at line 562 of file OperationValidator.cc.

563{
564 const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)};
565 const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)};
566
567 // Non-constant block_size and padding is not implemented yet
568 OP_REQUIRES(isConstant(block_size_index));
569 OP_REQUIRES(isConstant(paddings_index));
570}

References onert::ir::OperandIndexSequence::at(), onert::ir::operation::SpaceToBatchND::BLOCK_SIZE, onert::ir::Operation::getInputs(), OP_REQUIRES, and onert::ir::operation::SpaceToBatchND::PADDINGS.

◆ visit() [34/43]

void onert::ir::OperationValidator::visit (const operation::SpaceToDepth &node)
override

Definition at line 572 of file OperationValidator.cc.

573{
574 const auto block_size = node.param().block_size;
575 OP_REQUIRES(block_size >= 1);
576}

References onert::ir::operation::SpaceToDepth::Param::block_size, OP_REQUIRES, and onert::ir::operation::SpaceToDepth::param().

◆ visit() [35/43]

void onert::ir::OperationValidator::visit (const operation::Split &node)
override

Definition at line 578 of file OperationValidator.cc.

579{
580 const auto num_splits = node.param().num_splits;
581
582 OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
583 OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
584}

References onert::ir::Operation::getOutputs(), onert::ir::operation::Split::Param::num_splits, OP_REQUIRES, onert::ir::operation::Split::param(), and onert::ir::OperandIndexSequence::size().

◆ visit() [36/43]

void onert::ir::OperationValidator::visit (const operation::SquaredDifference &node)
override

Definition at line 586 of file OperationValidator.cc.

587{
588 const auto output_index{node.getOutputs().at(0)};
589 const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)};
590 const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)};
591
592 OP_REQUIRES(isSameType(output_index, lhs_index));
593 OP_REQUIRES(isSameType(lhs_index, rhs_index));
594}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::SquaredDifference::LHS, OP_REQUIRES, and onert::ir::operation::SquaredDifference::RHS.

◆ visit() [37/43]

void onert::ir::OperationValidator::visit (const operation::StatelessRandomUniform &node)
override

Definition at line 596 of file OperationValidator.cc.

597{
598 const auto output_index{node.getOutputs().at(0)};
599 const auto shape_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SHAPE)};
600 const auto seed_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SEED)};
601
602 OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
603 OP_REQUIRES(isValidType(shape_index, DataType::INT32));
604 OP_REQUIRES(isValidType(seed_index, DataType::INT32));
605}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, onert::ir::operation::StatelessRandomUniform::SEED, and onert::ir::operation::StatelessRandomUniform::SHAPE.

◆ visit() [38/43]

void onert::ir::OperationValidator::visit (const operation::StridedSlice &node)
override

Definition at line 607 of file OperationValidator.cc.

608{
609 const auto output_index{node.getOutputs().at(0)};
610 const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)};
611
612 OP_REQUIRES(isSameType(output_index, input_index));
613
614 if (isValidType(output_index, DataType::QUANT_INT16_SYMM))
615 {
616 OP_REQUIRES(isSameQuantParam(input_index, output_index));
617 }
618}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::StridedSlice::INPUT, and OP_REQUIRES.

◆ visit() [39/43]

void onert::ir::OperationValidator::visit (const operation::TopKV2 &node)
override

Definition at line 620 of file OperationValidator.cc.

621{
622 const auto output_value_index{node.getOutputs().at(0)};
623 const auto input_index{node.getInputs().at(operation::TopKV2::Input::INPUT)};
624
625 OP_REQUIRES(isSameType(output_value_index, input_index));
626 OP_REQUIRES(node.param().k >= 0);
627}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::TopKV2::INPUT, onert::ir::operation::TopKV2::Param::k, OP_REQUIRES, and onert::ir::operation::TopKV2::param().

◆ visit() [40/43]

void onert::ir::OperationValidator::visit (const operation::Transpose &node)
override

Definition at line 629 of file OperationValidator.cc.

630{
631 const auto output_index{node.getOutputs().at(0)};
632 const auto input_index{node.getInputs().at(operation::Transpose::Input::INPUT)};
633
634 OP_REQUIRES(isSameType(output_index, input_index));
635}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Transpose::INPUT, and OP_REQUIRES.

◆ visit() [41/43]

void onert::ir::OperationValidator::visit (const operation::TransposeConv &node)
override

◆ visit() [42/43]

void onert::ir::OperationValidator::visit (const operation::Unpack &node)
override

Definition at line 643 of file OperationValidator.cc.

644{
645 const auto num{node.param().num};
646 OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
647}

References onert::ir::Operation::getOutputs(), onert::ir::operation::Unpack::Param::num, OP_REQUIRES, onert::ir::operation::Unpack::param(), and onert::ir::OperandIndexSequence::size().

◆ visit() [43/43]

void onert::ir::OperationValidator::visit (const operation::While &node)
override

Definition at line 649 of file OperationValidator.cc.

650{
651 OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
652}

References onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), OP_REQUIRES, and onert::ir::OperandIndexSequence::size().


The documentation for this class was generated from the following files: