ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::TFShapeInferenceRule Struct Reference (final)

Shape inference rule for the TensorFlow dialect. More...

#include <TFShapeInferenceRule.h>

Collaboration diagram for moco::TFShapeInferenceRule:

Public Member Functions

bool support (const API &ver) const final
 Check whether a given API is available or not.
 
bool recognize (const loco::Dialect *) const final
 Return true if this rule recognizes a given dialect.
 
bool infer (const loco::Node *, loco::NodeShape &) const final
 Infer node's shape.
 
void infer (const Context *, const loco::Node *, Sink *) const final
 
- Public Member Functions inherited from loco::ShapeInferenceRule
virtual ~ShapeInferenceRule ()=default
 

Additional Inherited Members

- Public Types inherited from loco::ShapeInferenceRule
enum class API { V1, V2 }
 

Detailed Description

Shape inference rule for the TensorFlow dialect.

Definition at line 28 of file TFShapeInferenceRule.h.

Member Function Documentation

◆ infer() [1/2]

void moco::TFShapeInferenceRule::infer(const Context *ctx,
                                       const loco::Node *node,
                                       Sink *sink) const
[final], [virtual]

Reimplemented from loco::ShapeInferenceRule.

Definition at line 876 of file TFShapeInferenceRule.cpp.

877{
878 assert(node->dialect() == TFDialect::get());
879 assert(dynamic_cast<const TFNode *>(node) != nullptr);
880
881 ShapeInferenceAlgorithm alg{ctx};
882 auto shape = loco::must_cast<const TFNode *>(node)->accept(&alg);
883
884 if (shape.domain() == loco::Domain::Unknown)
885 sink->fail();
886 else
887 sink->okay(shape);
888}
loco::Node::dialect(): virtual const Dialect *dialect(void) const = 0
 Returns the "Dialect" identifier that this node belongs to.
moco::TFDialect::get(): static loco::Dialect *get(void)
 Definition at TFDialect.cpp:84.

References loco::Node::dialect(), moco::TFDialect::get(), and loco::Domain::Unknown.

◆ infer() [2/2]

bool moco::TFShapeInferenceRule::infer(const loco::Node *node,
                                       loco::NodeShape &shape) const
[final], [virtual]

Infer node's shape.

WARNING:

Implementations SHOULD return true only when shape inference succeeds.

Implements loco::ShapeInferenceRule.

Definition at line 857 of file TFShapeInferenceRule.cpp.

858{
859 ::compat::Context ctx;
860 ::compat::Sink sink;
861
862 infer(&ctx, node, &sink);
863
864 assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail);
865
866 if (sink.status() == ::compat::Sink::Fail)
867 {
868 return false;
869 }
870
871 shape = sink.shape();
872
873 return true;
874}
Definition infer.py:1

◆ recognize()

bool moco::TFShapeInferenceRule::recognize(const loco::Dialect *d) const
[final], [virtual]

Return true if this rule recognizes a given dialect.

Implements loco::ShapeInferenceRule.

Definition at line 851 of file TFShapeInferenceRule.cpp.

852{
853 // handle only TensorFlow dialect
854 return TFDialect::get() == d;
855}

References moco::TFDialect::get().

◆ support()

bool moco::TFShapeInferenceRule::support(const API &api) const
[final], [virtual]

Check whether a given API is available or not.

Reimplemented from loco::ShapeInferenceRule.

Definition at line 846 of file TFShapeInferenceRule.cpp.

References loco::ShapeInferenceRule::API::V1 and loco::ShapeInferenceRule::API::V2.


The documentation for this struct was generated from the following files: