ONE - On-device Neural Engine
Loading...
Searching...
No Matches
loco::ShapeInferenceSession Class Reference

#include <ShapeInference.h>

Public Member Functions

 ShapeInferenceSession (const ShapeInferenceRule *rule)
 
bool to (Graph *g) const
 

Detailed Description

Definition at line 36 of file ShapeInference.h.

Constructor & Destructor Documentation

◆ ShapeInferenceSession()

loco::ShapeInferenceSession::ShapeInferenceSession(const ShapeInferenceRule *rule)
inline

Definition at line 39 of file ShapeInference.h.

39 : _rule{rule}
40 {
41 // DO NOTHING
42 }

Member Function Documentation

◆ to()

bool loco::ShapeInferenceSession::to(Graph *g) const

Definition at line 68 of file ShapeInference.cpp.

69{
70 assert(_rule->support(ShapeInferenceRule::API::V1) && "API v1 is unavailable");
71
72 bool changed = false;
73
74 for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
75 {
76 if (_rule->recognize(node->dialect()))
77 {
78 loco::NodeShape shape;
79
80 if (!shape_known(node) && inputs_shape_ready(node))
81 {
82 if (_rule->infer(node, shape))
83 {
84 node->annot(std::make_unique<ShapeAnnotation>(shape));
85 changed = true;
86 }
87 }
88 }
89 }
90
91 return changed;
92}
std::vector< loco::Node * > postorder_traversal(const std::vector< loco::Node * > &roots)
Generate postorder traversal sequence starting from "roots".
Definition Algorithm.cpp:53
bool shape_known(const Node *node)
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
virtual bool support(const API &api) const
Check whether a given API is available or not.
virtual bool infer(const Node *, NodeShape &) const =0
Infer node's shape.
virtual bool recognize(const Dialect *) const =0
Return true if this rule recognizes a given dialect.

References loco::AnnotatedItem< Annotation >::annot(), loco::Node::dialect(), loco::ShapeInferenceRule::infer(), loco::output_nodes(), loco::postorder_traversal(), loco::ShapeInferenceRule::recognize(), loco::shape_known(), loco::ShapeInferenceRule::support(), and loco::ShapeInferenceRule::V1.

Referenced by exo::convert_to_TFLNodes(), exo::ShapeInferencePass::run(), exo::TypeInferencePass::run(), moco::tf::ShapeInferencePass::run(), and moco::tf::TypeInferencePass::run().


The documentation for this class was generated from the following files: