ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::RemoveTFIdentityNode Struct Reference (final)

Replace every use of a "TFIdentity" node with that node's input. More...

#include <RemoveTFIdentityNode.h>

Collaboration diagram for moco::RemoveTFIdentityNode:

Public Member Functions

const char * name (void) const final
 
bool run (loco::Graph *g) final
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

Replace every use of a "TFIdentity" node with that node's input, so consumers read directly from the node that fed the TFIdentity.

BEFORE: [X] -> [TFIdentity] -> [Y]

AFTER: [X] -> [Y]    (the [TFIdentity] node remains in the graph, with its input disconnected)

NOTE: This transform does not remove the "TFIdentity" node itself; a separate cleanup step is needed to delete it. This transform is identical to RemoveForwardNode.

Definition at line 40 of file RemoveTFIdentityNode.h.

Member Function Documentation

◆ name()

const char * moco::RemoveTFIdentityNode::name ( void  ) const
inline, final, virtual

Reimplemented from logo::Pass.

Definition at line 42 of file RemoveTFIdentityNode.h.

42{ return "RemoveTFIdentityNode"; }

◆ run()

bool moco::RemoveTFIdentityNode::run ( loco::Graph * g )
final, virtual

Run the pass.

Returns
true if at least one "TFIdentity" node was bypassed; false if nothing was changed

Implements logo::Pass.

Definition at line 27 of file RemoveTFIdentityNode.cpp.

28{
29 struct Collector final : public moco::TFNodeMutableVisitor<void>
30 {
31 void visit(moco::TFIdentity *node) final
32 {
33 if (node->input() != nullptr)
34 {
35 candidates.insert(node);
36 }
37 }
38
39 void visit(moco::TFNode *) final { return; }
40
41 std::set<moco::TFIdentity *> candidates;
42 };
43
44 Collector collector;
45
46 for (auto node : loco::all_nodes(g))
47 {
48 if (node->dialect() == moco::TFDialect::get())
49 {
50 auto tf_node = dynamic_cast<moco::TFNode *>(node);
51 // NOTE our analysis tool reports an error for tf_node may be nullptr
52 if (tf_node != nullptr)
53 tf_node->accept(&collector);
54 }
55 }
56
57 for (auto node : collector.candidates)
58 {
59 replace(node).with(node->input());
60 node->input(nullptr);
61 }
62
63 return collector.candidates.size() > 0;
64}
void with(Node *into) const
Definition Node.cpp:66
static loco::Dialect * get(void)
Definition TFDialect.cpp:84
std::set< Node * > all_nodes(Graph *)
Enumerate all the nodes in a given graph.
Definition Graph.cpp:59
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
T accept(TFNodeVisitorBase< T > *) const
Definition TFNodeImpl.h:28

References moco::TFNode::accept(), loco::all_nodes(), moco::TFDialect::get(), moco::TFIdentity::input(), and loco::Subst< SubstQualifier::Default >::with().

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following files: