ONE - On-device Neural Engine
loco::MultiDialectShapeInferenceRule Class Reference [final]

Shape inference rule for multiple dialects.

#include <MultiDialectShapeInferenceRule.h>

Public Member Functions

bool recognize (const Dialect *) const final
 Return true if this rule recognizes a given dialect.
 
bool infer (const Node *, NodeShape &) const final
 Infer node's shape.
 
MultiDialectShapeInferenceRule & bind (const Dialect *d, const ShapeInferenceRule *rule)
 Bind a specific rule to a Dialect.
 
- Public Member Functions inherited from loco::ShapeInferenceRule
virtual ~ShapeInferenceRule ()=default
 
virtual bool support (const API &api) const
 Check whether a given API is available or not.
 
virtual void infer (const Context *, const Node *, Sink *) const
 

Additional Inherited Members

- Public Types inherited from loco::ShapeInferenceRule
enum class API { V1, V2 }
 

Detailed Description

Shape inference rule for multiple dialects.

Definition at line 30 of file MultiDialectShapeInferenceRule.h.

Member Function Documentation

◆ bind()

MultiDialectShapeInferenceRule &loco::MultiDialectShapeInferenceRule::bind(const Dialect *d, const ShapeInferenceRule *rule)

Bind a specific rule to a Dialect.

Definition at line 56 of file MultiDialectShapeInferenceRule.cpp.

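```cpp
{
  // Each dialect may be bound only once.
  assert(_rules.find(d) == _rules.end());
  // The bound rule must itself recognize the dialect.
  assert(rule->recognize(d));

  _rules[d] = rule;

  // Return *this so that bind() calls can be chained.
  return (*this);
}
```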

References loco::ShapeInferenceRule::recognize().

Referenced by exo::ShapeInferencePass::run(), and moco::tf::ShapeInferencePass::run().
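To illustrate the registration contract, the following self-contained sketch models `bind()` together with `recognize()`. It is not loco's source: `Dialect`, `Rule`, `SingleDialectRule`, and `MultiRuleSketch` are hypothetical stand-ins, and keeping the bindings in a `std::map` keyed by dialect pointer is an assumption that merely fits the `find`/`operator[]` calls in the listing. It shows that `bind()` keeps a non-owning pointer and returns `*this`, so several dialects can be registered in one expression.

```cpp
#include <cassert>
#include <map>

// Hypothetical stand-ins for loco::Dialect and loco::ShapeInferenceRule.
struct Dialect
{
};

struct Rule
{
  virtual ~Rule() = default;
  virtual bool recognize(const Dialect *d) const = 0;
};

// A leaf rule that recognizes exactly one dialect.
class SingleDialectRule final : public Rule
{
public:
  explicit SingleDialectRule(const Dialect *d) : _dialect{d} {}
  bool recognize(const Dialect *d) const override { return d == _dialect; }

private:
  const Dialect *_dialect;
};

// Hypothetical model of bind() and recognize() from MultiDialectShapeInferenceRule.
class MultiRuleSketch final : public Rule
{
public:
  MultiRuleSketch &bind(const Dialect *d, const Rule *rule)
  {
    assert(_rules.find(d) == _rules.end()); // each dialect is bound at most once
    assert(rule->recognize(d));             // the rule must accept the dialect
    _rules[d] = rule;                       // non-owning: rule must outlive *this
    return *this;                           // allows chained bind() calls
  }

  bool recognize(const Dialect *d) const override
  {
    // Recognized only if a rule is bound for d and that rule accepts d.
    const auto found = _rules.find(d);
    return found != _rules.cend() && found->second->recognize(d);
  }

private:
  std::map<const Dialect *, const Rule *> _rules;
};

// Usage: bind two dialects in one chained expression.
bool chained_bind_example()
{
  Dialect canonical, custom, other;
  SingleDialectRule canonical_rule{&canonical};
  SingleDialectRule custom_rule{&custom};

  // Declared after the rules, so it is destroyed before them.
  MultiRuleSketch multi;
  multi.bind(&canonical, &canonical_rule).bind(&custom, &custom_rule);

  // Bound dialects are recognized; a dialect that was never bound is not.
  return multi.recognize(&canonical) && multi.recognize(&custom) && !multi.recognize(&other);
}
```

Because the sketch stores raw pointers, the leaf rules are declared before the composite so that they outlive it.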

◆ infer()

bool loco::MultiDialectShapeInferenceRule::infer(const Node *node, NodeShape &shape) const
[final, virtual]

Infer node's shape.

Warning: an implementation SHOULD return true only when inference succeeds.

Implements loco::ShapeInferenceRule.

Definition at line 42 of file MultiDialectShapeInferenceRule.cpp.

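```cpp
{
  // Look up the rule bound to this node's dialect.
  const auto found = _rules.find(node->dialect());

  // No rule is bound for this dialect: inference fails.
  if (found == _rules.cend())
    return false;

  // Delegate to the bound rule and report its result.
  auto rule = found->second;
  if (rule->infer(node, shape))
    return true;

  return false;
}
```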

References loco::Node::dialect().
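The dispatch in `infer()` can be modelled in a few lines of standalone C++. In this sketch, `Dialect`, `Node`, `NodeShape`, `Rule`, `RankRule`, and `MultiRuleSketch` are hypothetical stand-ins, and `NodeShape` is reduced to a single rank field purely for illustration. The composite looks up the rule for `node->dialect()`, returns `false` when none is bound, and otherwise delegates. The leaf `RankRule` follows the warning above by returning `true` only when it actually produced a shape.

```cpp
#include <cassert>
#include <map>

// Hypothetical stand-ins for loco::Dialect, loco::Node, and loco::NodeShape.
struct Dialect
{
};

struct Node
{
  const Dialect *d;
  int rank; // -1 means "rank unknown" in this sketch
  const Dialect *dialect() const { return d; }
};

struct NodeShape
{
  int rank = -1;
};

struct Rule
{
  virtual ~Rule() = default;
  virtual bool infer(const Node *node, NodeShape &shape) const = 0;
};

// Leaf rule: succeeds only when the node's rank is known.
struct RankRule final : Rule
{
  bool infer(const Node *node, NodeShape &shape) const override
  {
    if (node->rank < 0)
      return false; // inference failed: report false and leave shape untouched
    shape.rank = node->rank;
    return true;
  }
};

// Hypothetical model of MultiDialectShapeInferenceRule::infer().
class MultiRuleSketch final : public Rule
{
public:
  MultiRuleSketch &bind(const Dialect *d, const Rule *rule)
  {
    _rules[d] = rule;
    return *this;
  }

  bool infer(const Node *node, NodeShape &shape) const override
  {
    const auto found = _rules.find(node->dialect());
    if (found == _rules.cend())
      return false;                           // no rule bound for this dialect
    return found->second->infer(node, shape); // delegate to the bound rule
  }

private:
  std::map<const Dialect *, const Rule *> _rules;
};

// Usage: checks the three possible outcomes of a dispatch.
bool dispatch_behaves_as_documented()
{
  Dialect known, unknown;
  RankRule rule;
  MultiRuleSketch multi;
  multi.bind(&known, &rule);

  NodeShape shape;
  const Node unbound{&unknown, 2};
  const Node failing{&known, -1};
  const Node ok{&known, 4};

  const bool r1 = !multi.infer(&unbound, shape) && shape.rank == -1; // no rule bound
  const bool r2 = !multi.infer(&failing, shape) && shape.rank == -1; // leaf rule fails
  const bool r3 = multi.infer(&ok, shape) && shape.rank == 4;        // success
  return r1 && r2 && r3;
}
```

Returning the leaf rule's result directly is equivalent to the `if (...) return true; return false;` form in the listing.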

◆ recognize()

bool loco::MultiDialectShapeInferenceRule::recognize(const Dialect *d) const
[final, virtual]

Return true if this rule recognizes a given dialect.

Implements loco::ShapeInferenceRule.

Definition at line 29 of file MultiDialectShapeInferenceRule.cpp.

```cpp
{
  const auto found = _rules.find(d);

  // A dialect with no bound rule is not recognized.
  if (found == _rules.cend())
    return false;

  // Otherwise, defer to the bound rule.
  auto rule = found->second;
  auto result = rule->recognize(d);

  return result;
}
```
