ONE - On-device Neural Engine
exo::MergeConcatNodesPass Class Reference

Merge concat nodes whose axis and fusedActivationFunction are the same.

#include <MergeConcatNodesPass.h>


Public Member Functions

virtual const char * name (void) const

bool run (loco::Graph *graph)
    Merge TFLConcatenation nodes whose axis and fusedActivationFunction are the same.

- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

Merge concat nodes whose axis and fusedActivationFunction are the same.

Definition at line 30 of file MergeConcatNodesPass.h.
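As an illustration (this derivation is not part of the ONE sources), the identities below suggest why the merge preserves values when both concatenations share the axis k and the fused activation f, assuming f is applied elementwise to the concatenated output. Here a, b, c, u, v are tensors whose shapes agree on every dimension except k.

```latex
\begin{aligned}
\operatorname{concat}_k\bigl(a, \operatorname{concat}_k(b, c)\bigr) &= \operatorname{concat}_k(a, b, c) \\
f\bigl(\operatorname{concat}_k(u, v)\bigr) &= \operatorname{concat}_k\bigl(f(u), f(v)\bigr) \\
f\Bigl(\operatorname{concat}_k\bigl(a, f(\operatorname{concat}_k(b, c))\bigr)\Bigr)
  &= \operatorname{concat}_k\bigl(f(a), f(f(b)), f(f(c))\bigr) \\
f\bigl(\operatorname{concat}_k(a, b, c)\bigr) &= \operatorname{concat}_k\bigl(f(a), f(b), f(c)\bigr)
\end{aligned}
```

The third line (nested nodes) and the fourth line (merged node) agree whenever f is idempotent, that is f(f(x)) = f(x). This holds for NONE (the identity), RELU, and RELU6, but not in general for an activation such as TANH. When the axes differ, as for TFLConcatenation:1 in the run() example, the first identity does not apply.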

Member Function Documentation

◆ name()

virtual const char * exo::MergeConcatNodesPass::name ( void ) const
inline virtual

Reimplemented from logo::Pass.

Definition at line 33 of file MergeConcatNodesPass.h.

{ return "exo::MergeConcatNodesPass"; }

◆ run()

bool exo::MergeConcatNodesPass::run ( loco::Graph * graph )
virtual

Merge TFLConcatenation nodes whose axis and fusedActivationFunction are the same.

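[Before]

  in:0 --------------------------------\
  in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C
            (axis = 0, NONE)            (axis = 0, NONE)
  in:2 ---/                           /
  in:3 ---- TFLConcatenation:1 ------/
            (axis = 1, NONE)        /
  in:4 ---/                        /
  in:5 ---- TFLConcatenation:2 ---/
            (axis = 0, RELU)
  in:6 ---/

[After]

  in:0 --------------------------------\
  in:1 -------------------------------- TFLConcatenation:4 --- C
                                        (axis = 0, NONE)
  in:2 -------------------------------/
  in:3 ---- TFLConcatenation:1 ------/
            (axis = 1, NONE)        /
  in:4 ---/                        /
  in:5 ---- TFLConcatenation:2 ---/
            (axis = 0, RELU)
  in:6 ---/

TFLConcatenation:0 has the same axis (0) and fusedActivationFunction (NONE) as TFLConcatenation:3, so its inputs in:1 and in:2 are linked directly into the new TFLConcatenation:4. TFLConcatenation:1 (different axis) and TFLConcatenation:2 (different fusedActivationFunction) cannot be merged and remain inputs. Afterwards, TFLConcatenation:0 has no users, and the replaced TFLConcatenation:3 is detached from its inputs:

  in:1 ---- TFLConcatenation:0 ----
            (axis = 0, NONE)
  in:2 ---/

       ---- TFLConcatenation:3 ----
            (axis = 0, NONE)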

Implements logo::Pass.

Definition at line 147 of file MergeConcatNodesPass.cpp.

{
  // Enumerate the nodes required to compute the output nodes
  auto active_nodes = loco::active_nodes(loco::output_nodes(graph));

  // Find TFLConcatenation nodes that have at least one other TFLConcatenation
  // node as input with the same axis and the same fusedActivationFunction.
  // canMerge() and dfs() are helper functions not shown on this page.
  std::vector<locoex::TFLConcatenation *> candidates;
  for (auto node : active_nodes)
  {
    if (auto concat = dynamic_cast<locoex::TFLConcatenation *>(node))
    {
      for (uint32_t i = 0; i < concat->numValues(); ++i)
      {
        auto input = dynamic_cast<locoex::TFLConcatenation *>(concat->values(i));
        if (input != nullptr && canMerge(input, concat))
        {
          candidates.push_back(concat);
          break;
        }
      }
    }
  }

  // Merge each candidate and its mergeable TFLConcatenation inputs into one node
  for (auto node : candidates)
  {
    // Flattened input list; for the example above:
    // in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2
    auto inputs = dfs(node);

    auto new_concat = graph->nodes()->create<locoex::TFLConcatenation>(inputs.size());
    new_concat->axis(node->axis());
    new_concat->fusedActivationFunction(node->fusedActivationFunction());

    for (uint32_t i = 0; i < inputs.size(); ++i)
      new_concat->values(i, inputs.at(i));

    // Redirect every user of the old node to the new one,
    // then detach the old node from its inputs
    loco::replace(node).with(new_concat);
    for (uint32_t i = 0; i < node->numValues(); ++i)
      node->values(i, nullptr);
  }

  return candidates.size() > 0;
}

References loco::active_nodes(), locoex::TFLConcatenation::axis(), loco::output_nodes(), loco::replace(), and loco::Subst< SubstQualifier::Default >::with().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: