ONE - On-device Neural Engine
exo::TensorBroadcastConverter Class Reference

Pass to resolve TensorBroadcast IR.

#include <TensorBroadcastConverter.h>


Public Member Functions

virtual const char * name (void) const
bool run (loco::Graph *graph)
    Disconnects loco::TensorBroadcast from the graph if the following node is one of the binary nodes TFLAdd, TFLSub, TFLMul, TFLDiv, or TFLMaximum and meets the condition (TBA).
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default

Detailed Description

Pass to resolve TensorBroadcast IR.

Definition at line 29 of file TensorBroadcastConverter.h.
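Why can an explicit broadcast be dropped at all? A plausible reading, assuming the TFLite binary kernels stretch size-1 dimensions on their own (NumPy-style broadcasting), is that duplicating a row explicitly and then adding gives the same result as adding with implicit broadcasting. The sketch below checks that equivalence on a toy 2-D tensor. Tensor2D, broadcast_rows, add_same_shape, and add_broadcasting are hypothetical helpers written for this illustration, not ONE APIs, and the pass's actual mapping condition (TBA) is not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major 2-D tensor, just enough for the illustration.
struct Tensor2D
{
  std::size_t rows;
  std::size_t cols;
  std::vector<float> data;

  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Explicit broadcast: copy a 1 x N row `rows` times, as a TensorBroadcast
// node would materialize it.
Tensor2D broadcast_rows(const Tensor2D &row, std::size_t rows)
{
  assert(row.rows == 1);
  Tensor2D out{rows, row.cols, {}};
  out.data.reserve(rows * row.cols);
  for (std::size_t r = 0; r < rows; ++r)
    out.data.insert(out.data.end(), row.data.begin(), row.data.end());
  return out;
}

// Element-wise add of two tensors that already have identical shapes.
Tensor2D add_same_shape(const Tensor2D &a, const Tensor2D &b)
{
  assert(a.rows == b.rows && a.cols == b.cols);
  Tensor2D out{a.rows, a.cols, std::vector<float>(a.data.size())};
  for (std::size_t i = 0; i < a.data.size(); ++i)
    out.data[i] = a.data[i] + b.data[i];
  return out;
}

// Element-wise add with implicit broadcasting: a size-1 dimension is read
// repeatedly instead of being copied.
Tensor2D add_broadcasting(const Tensor2D &a, const Tensor2D &b)
{
  const std::size_t rows = std::max(a.rows, b.rows);
  const std::size_t cols = std::max(a.cols, b.cols);
  assert((a.rows == rows || a.rows == 1) && (b.rows == rows || b.rows == 1));
  assert((a.cols == cols || a.cols == 1) && (b.cols == cols || b.cols == 1));
  Tensor2D out{rows, cols, std::vector<float>(rows * cols)};
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      out.data[r * cols + c] = a.at(a.rows == 1 ? 0 : r, a.cols == 1 ? 0 : c) +
                               b.at(b.rows == 1 ? 0 : r, b.cols == 1 ? 0 : c);
  return out;
}
```

With y of shape 3 x 2 and x of shape 1 x 2, add_same_shape(y, broadcast_rows(x, 3)) and add_broadcasting(y, x) produce the same values; the second form simply never allocates the duplicated tensor.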

Member Function Documentation

◆ name()

virtual const char * exo::TensorBroadcastConverter::name (void) const
inline virtual

Reimplemented from logo::Pass.

Definition at line 32 of file TensorBroadcastConverter.h.

{ return "exo::TensorBroadcastConverter"; }
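name() only supplies an identifier, while run() reports through its bool result whether it modified the graph. A minimal sketch of how a driver might consume both, rerunning passes until none reports a change: Graph, Pass, OneStepPass, and run_to_fixed_point are hypothetical stand-ins written for this illustration, not the loco or logo classes, and logo's own pass-scheduling machinery is not reproduced.

```cpp
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for loco::Graph: a counter of rewrites still possible.
struct Graph
{
  int pending_rewrites;
};

// Hypothetical stand-in for logo::Pass, mirroring only name() and run().
class Pass
{
public:
  virtual ~Pass() = default;
  virtual const char *name(void) const = 0;
  virtual bool run(Graph *graph) = 0;
};

// Toy pass: performs one rewrite per call and says whether it did anything.
class OneStepPass final : public Pass
{
public:
  const char *name(void) const override { return "demo::OneStepPass"; }
  bool run(Graph *graph) override
  {
    if (graph->pending_rewrites == 0)
      return false;
    --graph->pending_rewrites;
    return true;
  }
};

// Hypothetical driver: sweep over all passes until a full sweep changes
// nothing, logging each change by name(). Returns the number of sweeps.
int run_to_fixed_point(Graph *graph, const std::vector<std::unique_ptr<Pass>> &passes)
{
  int sweeps = 0;
  bool changed = true;
  while (changed)
  {
    changed = false;
    ++sweeps;
    for (const auto &pass : passes)
    {
      if (pass->run(graph))
      {
        std::clog << "changed by " << pass->name() << '\n';
        changed = true;
      }
    }
  }
  return sweeps;
}
```

With three pending rewrites the driver needs four sweeps: three that each change the graph and a final one that confirms nothing is left to do.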

◆ run()

bool exo::TensorBroadcastConverter::run (loco::Graph * graph)
virtual

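Disconnects loco::TensorBroadcast from the graph if the following node is one of the binary nodes TFLAdd, TFLSub, TFLMul, TFLDiv, or TFLMaximum and meets the condition (TBA). The binary node is rewired to read the broadcast's input directly, as the note below shows.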

Note
TFLXXX stands for any of the binary nodes listed above.

Before:
    x --- TensorBroadcast --- TFLXXX --- output
    y -----------------------/

After:
      --- TensorBroadcast ---
    x --- TFLXXX --- output
    y ---/
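The rewiring in the note can be sketched on a toy node model. Node, Broadcast, Binary, and jump_connection below are hypothetical stand-ins: the page calls a helper jump_connection<T> but does not show its body. The sketch points whichever operand of the binary node reads the broadcast at the broadcast's input instead, then clears the broadcast's input so that it is detached on both sides, as in the After picture.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical minimal node model, not the loco/locoex classes.
struct Node
{
  virtual ~Node() = default;
};

struct Broadcast : Node
{
  Node *input = nullptr; // the tensor being duplicated ("x" in the note)
};

struct Binary : Node
{
  Node *x = nullptr; // first operand
  Node *y = nullptr; // second operand
};

// Make the operand of `op` that reads `tb` read tb's input directly,
// then detach tb from its own input.
void jump_connection(Broadcast *tb, Binary *op)
{
  if (op->x == tb)
    op->x = tb->input;
  else if (op->y == tb)
    op->y = tb->input;
  else
    throw std::logic_error("op does not consume tb");
  tb->input = nullptr;
}
```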

Implements logo::Pass.

Definition at line 132 of file TensorBroadcastConverter.cpp.

{
  Collector collector;

  // Visit every TFL node that contributes to a graph output; the collector
  // records (TensorBroadcast, binary node) candidate pairs.
  auto active_nodes = loco::active_nodes(loco::output_nodes(graph));

  for (auto node : active_nodes)
  {
    if (node->dialect() == locoex::TFLDialect::get())
    {
      auto tfl_node = loco::must_cast<locoex::TFLNode *>(node);
      tfl_node->accept(&collector);
    }
  }

  bool changed = false;

  // Bypass the broadcast for each candidate that satisfies the mapping
  // condition, dispatching on the concrete binary node type.
  for (auto pair : collector.candidates)
  {
    if (mapping_condition(pair))
    {
      loco::TensorBroadcast *tensorbroadcast = pair.first;
      if (auto tfladd = dynamic_cast<locoex::TFLAdd *>(pair.second))
      {
        jump_connection<locoex::TFLAdd>(tensorbroadcast, tfladd);
        changed = true;
      }
      else if (auto tfldiv = dynamic_cast<locoex::TFLDiv *>(pair.second))
      {
        jump_connection<locoex::TFLDiv>(tensorbroadcast, tfldiv);
        changed = true;
      }
      else if (auto tflmul = dynamic_cast<locoex::TFLMul *>(pair.second))
      {
        jump_connection<locoex::TFLMul>(tensorbroadcast, tflmul);
        changed = true;
      }
      else if (auto tflsub = dynamic_cast<locoex::TFLSub *>(pair.second))
      {
        jump_connection<locoex::TFLSub>(tensorbroadcast, tflsub);
        changed = true;
      }
      else if (auto tflmaximum = dynamic_cast<locoex::TFLMaximum *>(pair.second))
      {
        jump_connection<locoex::TFLMaximum>(tensorbroadcast, tflmaximum);
        changed = true;
      }
      else
      {
        // Unreachable if the collector only records the supported binary nodes.
        assert(false);
      }
    }
  }

  return changed;
}
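Stripped of the TFL dialect and the visitor, run() collects candidate pairs, filters them through a condition, rewires each surviving pair, and returns whether anything changed. A self-contained sketch of that control flow on a hypothetical node model follows. Kind, Node, collect, mapping_condition, and run here are illustrative only; the condition simply accepts every pair because the real one is still TBA, and a kind check replaces the dynamic_cast chain.

```cpp
#include <cassert>
#include <initializer_list>
#include <utility>
#include <vector>

// Hypothetical node model (not loco): a kind plus up to two inputs.
enum class Kind { Input, Broadcast, Add, Sub, Mul, Div, Maximum, Other };

struct Node
{
  Kind kind;
  Node *in0; // for Broadcast: the tensor being duplicated
  Node *in1;
};

// The five binary kinds handled by the pass.
bool is_supported_binary(Kind k)
{
  return k == Kind::Add || k == Kind::Sub || k == Kind::Mul || k == Kind::Div ||
         k == Kind::Maximum;
}

using Candidate = std::pair<Node *, Node *>; // (broadcast, consumer)

// Record every (Broadcast, supported binary node) pair where the node reads the broadcast.
std::vector<Candidate> collect(const std::vector<Node *> &nodes)
{
  std::vector<Candidate> out;
  for (Node *n : nodes)
  {
    if (!is_supported_binary(n->kind))
      continue;
    for (Node *arg : {n->in0, n->in1})
      if (arg != nullptr && arg->kind == Kind::Broadcast)
        out.emplace_back(arg, n);
  }
  return out;
}

// Placeholder for the "(TBA)" condition: accepts every candidate.
bool mapping_condition(const Candidate &) { return true; }

// Bypass each accepted broadcast; report whether any operand was rewired.
bool run(const std::vector<Node *> &nodes)
{
  bool changed = false;
  for (const Candidate &c : collect(nodes))
  {
    if (!mapping_condition(c))
      continue;
    Node *tb = c.first;
    Node *op = c.second;
    bool rewired = false;
    if (op->in0 == tb) { op->in0 = tb->in0; rewired = true; }
    if (op->in1 == tb) { op->in1 = tb->in0; rewired = true; }
    changed = changed || rewired;
  }
  return changed;
}
```

Because changed is set only when an operand is actually rewired, a second call on the same graph returns false, which is what lets a fixed-point driver stop.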
Symbols used above:
loco::TensorBroadcast: Duplicate elements along specified axes. Definition Nodes.h:980
locoex::TFLAdd: ADD in TensorFlow Lite. Definition TFLNodes.h:116
locoex::TFLDiv: DIV in TensorFlow Lite. Definition TFLNodes.h:280
locoex::TFLMaximum: MAXIMUM in TensorFlow Lite. Definition TFLNodes.h:314
locoex::TFLMul: MUL in TensorFlow Lite. Definition TFLNodes.h:375
locoex::TFLSub: SUB in TensorFlow Lite. Definition TFLNodes.h:488
locoex::TFLDialect::get: static loco::Dialect * get(void)
loco::active_nodes: std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots). Enumerate all the nodes required to compute "roots".
loco::output_nodes: std::vector< Node * > output_nodes(Graph *). Definition Graph.cpp:101
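loco::active_nodes is described as enumerating all the nodes required to compute "roots", and run() seeds it with the graph's output nodes. One way to compute such a set is a backward depth-first walk over input edges, sketched here on a hypothetical Node type; active_nodes_sketch is an illustration, not the loco implementation.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical node model: each node lists the nodes it reads from.
struct Node
{
  std::vector<Node *> args;
};

// Collect every node reachable from `roots` by following input edges
// backwards (iterative DFS). Anything not reached is dead code.
std::set<Node *> active_nodes_sketch(const std::vector<Node *> &roots)
{
  std::set<Node *> visited;
  std::vector<Node *> stack(roots.begin(), roots.end());
  while (!stack.empty())
  {
    Node *n = stack.back();
    stack.pop_back();
    if (n == nullptr || !visited.insert(n).second)
      continue; // skip disconnected inputs and already-visited nodes
    for (Node *arg : n->args)
      stack.push_back(arg);
  }
  return visited;
}
```

In this model, once a broadcast has been bypassed and nothing else reads it, a later walk from the outputs no longer reaches it.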

References loco::active_nodes(), locoex::TFLDialect::get(), and loco::output_nodes().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: