ONE - On-device Neural Engine
logo::RemoveForwardNodePass Struct Reference (final)

Use the input of a "Forward" node in place of the node itself.

#include <RemoveForwardNodePass.h>


Public Member Functions

const char * name (void) const final
 
bool run (loco::Graph *g) final
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

Use the input of a "Forward" node in place of the node itself.

BEFORE: [X] -> [Forward] -> [Y]

AFTER: [X] -> [Y] [Forward]

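NOTE: This transform does not remove the "Forward" node from the graph; it only rewires the node's users to its input and clears the node's own input.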
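The BEFORE/AFTER rewiring can be illustrated on a toy graph. The sketch below is only an illustration under stated assumptions: `ToyNode`, `bypass_forward_nodes`, and `demo_bypass` are hypothetical names invented for this example and are not part of loco or logo. It models each node as having a single input pointer, which is a simplification.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a graph node; this is NOT the loco::Node API.
struct ToyNode
{
  ToyNode(std::string k, ToyNode *in = nullptr) : kind(std::move(k)), input(in) {}

  std::string kind; // e.g. "X", "Forward", "Y"
  ToyNode *input;   // single producer feeding this node (assumption)
};

// Rewire every user of a "Forward" node to the Forward's own input, then
// detach the Forward. The Forward itself stays in the node list.
// Returns true if at least one Forward was bypassed.
bool bypass_forward_nodes(std::vector<ToyNode *> &nodes)
{
  // Collect candidates first so that rewiring does not disturb the scan.
  std::vector<ToyNode *> candidates;
  for (ToyNode *n : nodes)
    if (n->kind == "Forward" && n->input != nullptr)
      candidates.push_back(n);

  for (ToyNode *fwd : candidates)
  {
    for (ToyNode *user : nodes)
      if (user->input == fwd)
        user->input = fwd->input; // [Y] now reads from [X]
    fwd->input = nullptr;         // [Forward] is left disconnected
  }
  return !candidates.empty();
}

// Builds [X] -> [Forward] -> [Y], applies the rewrite, and checks the result.
bool demo_bypass()
{
  ToyNode x("X"), fwd("Forward", &x), y("Y", &fwd);
  std::vector<ToyNode *> nodes{&x, &fwd, &y};

  const bool changed = bypass_forward_nodes(nodes);
  assert(changed);

  // A second application finds no Forward with an input, so nothing changes.
  const bool changed_again = bypass_forward_nodes(nodes);

  return y.input == &x && fwd.input == nullptr && nodes.size() == 3 && !changed_again;
}
```

Calling `demo_bypass()` exercises the three properties described above: `[Y]` ends up reading from `[X]`, the Forward node is still present but disconnected, and a second application reports no change.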

Definition at line 37 of file RemoveForwardNodePass.h.

Member Function Documentation

◆ name()

const char * logo::RemoveForwardNodePass::name ( void  ) const
inline final virtual

Reimplemented from logo::Pass.

Definition at line 39 of file RemoveForwardNodePass.h.

{ return "RemoveForwardNodePass"; }

◆ run()

bool logo::RemoveForwardNodePass::run ( loco::Graph * g )
final virtual

Run the pass.

Returns
false if nothing was changed
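A boolean "changed" result like this is commonly used by a caller to decide whether another sweep over the passes is worthwhile. The following is a minimal sketch of that idea only. `ToyGraph`, `ToyPass`, `ToyRemoveForward`, `run_to_fixed_point`, and `demo_sweeps` are hypothetical stand-ins made up for this example; they are not the logo::Pass or loco::Graph types, and this is not necessarily how logo schedules its passes.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for a graph; only counts Forward nodes still wired in.
struct ToyGraph
{
  int wired_forwards = 0;
};

// Hypothetical pass interface mirroring the name()/run() shape on this page.
struct ToyPass
{
  virtual ~ToyPass() = default;
  virtual const char *name() const = 0;
  // Same convention as above: return false if nothing was changed.
  virtual bool run(ToyGraph *g) = 0;
};

struct ToyRemoveForward final : ToyPass
{
  const char *name() const final { return "ToyRemoveForward"; }

  bool run(ToyGraph *g) final
  {
    if (g->wired_forwards == 0)
      return false; // nothing to do, so report "unchanged"
    g->wired_forwards = 0;
    return true;
  }
};

// Repeat all passes until one full sweep reports no change.
// Returns the number of sweeps performed.
int run_to_fixed_point(ToyGraph *g, const std::vector<std::unique_ptr<ToyPass>> &passes)
{
  int sweeps = 0;
  bool changed = true;
  while (changed)
  {
    changed = false;
    for (const auto &pass : passes)
      changed = pass->run(g) || changed; // always run the pass, then accumulate
    ++sweeps;
  }
  return sweeps;
}

// One productive sweep followed by one sweep that confirms nothing is left.
int demo_sweeps()
{
  ToyGraph g;
  g.wired_forwards = 2;

  std::vector<std::unique_ptr<ToyPass>> passes;
  passes.push_back(std::unique_ptr<ToyPass>(new ToyRemoveForward));

  const int sweeps = run_to_fixed_point(&g, passes);
  assert(g.wired_forwards == 0);
  return sweeps;
}
```

In this toy setup the loop stops after the first sweep in which every pass returns false, which is why an accurate return value matters: a pass that always returned true would keep the loop running forever.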

Implements logo::Pass.

Definition at line 27 of file RemoveForwardNodePass.cpp.

{
  // Collects every Forward node that still has an input.
  struct Collector final : public loco::CanonicalNodeMutableVisitor<void>
  {
    void visit(loco::Forward *node) final
    {
      if (node->input() != nullptr)
      {
        candidates.insert(node);
      }
    }

    // All other canonical nodes are ignored.
    void visit(loco::Node *) final { return; }

    std::set<loco::Forward *> candidates;
  };

  Collector collector;

  // Only nodes of the canonical dialect are handed to the collector.
  for (auto node : loco::all_nodes(g))
  {
    if (node->dialect() == loco::CanonicalDialect::get())
    {
      auto canonical_node = loco::must_cast<loco::CanonicalNode *>(node);
      canonical_node->accept(&collector);
    }
  }

  // Rewire the users of each Forward to its input, then detach the Forward.
  for (auto node : collector.candidates)
  {
    replace(node).with(node->input());
    node->input(nullptr);
  }

  // Report whether anything was changed.
  return collector.candidates.size() > 0;
}
loco::CanonicalDialect::get: static Dialect * get(void)
loco::Forward: Create a new value identical to its input. Definition Nodes.h:146
loco::Node: Logical unit of computation. Definition Node.h:54
loco::Subst< SubstQualifier::Default >::with: void with(Node *into) const. Definition Node.cpp:66
loco::all_nodes: std::set< Node * > all_nodes(Graph *). Enumerate all the nodes in a given graph. Definition Graph.cpp:59
loco::replace: Subst< SubstQualifier::Default > replace(Node *node). Definition Node.cpp:82

References loco::all_nodes(), loco::CanonicalDialect::get(), loco::Forward::input(), and loco::Subst< SubstQualifier::Default >::with().

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following files: