ONE - On-device Neural Engine
Loading...
Searching...
No Matches
logo::SimplifyDomainConversionPass Struct Reference [final]

Simplify redundant domain conversion. More...

#include <SimplifyDomainConversionPass.h>

Collaboration diagram for logo::SimplifyDomainConversionPass:

Public Member Functions

const char * name (void) const final
 
bool run (loco::Graph *g) final
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

Simplify redundant domain conversion.

SimplifyDomainConversionPass recognizes the following patterns:

  • FeatureDecode followed by FeatureEncode (Feature -> Tensor -> Feature)
  • FeatureEncode followed by FeatureDecode (Tensor -> Feature -> Tensor)
  • FilterEncode followed by FilterDecode (Tensor -> Filter -> Tensor)
  • BiasEncode followed by BiasDecode (Tensor -> Bias -> Tensor)
  • DepthwiseFilterEncode followed by DepthwiseFilterDecode (Tensor -> DepthwiseFilter -> Tensor)
  • MatrixDecode followed by MatrixEncode (Matrix -> Tensor -> Matrix)
  • MatrixEncode followed by MatrixDecode (Tensor -> Matrix -> Tensor)
  • (TO BE ADDED)

Definition at line 38 of file SimplifyDomainConversionPass.h.

Member Function Documentation

◆ name()

const char * logo::SimplifyDomainConversionPass::name ( void  ) const
[inline, final, virtual]

Reimplemented from logo::Pass.

Definition at line 40 of file SimplifyDomainConversionPass.h.

40{ return "SimplifyDomainConversionPass"; }

◆ run()

bool logo::SimplifyDomainConversionPass::run ( loco::Graph * graph )
[final, virtual]

Run the pass.

Returns
false if nothing was changed

Implements logo::Pass.

Definition at line 115 of file SimplifyDomainConversionPass.cpp.

// Body of SimplifyDomainConversionPass::run. The leading number on each line is the
// source line number rendered by the documentation generator, not part of the code.
//
// Overview: a local visitor (Collector) walks the active canonical nodes of the graph and
// records redundant encode/decode pairs. Pairs whose permutations are equal are replaced
// with a loco::Forward node; Filter, DepthwiseFilter and Matrix encode->decode pairs whose
// permutations differ are replaced with a loco::TensorTranspose node. The function returns
// true when at least one candidate was collected.
116{
117 // TODO Introduce and Use "Pattern Match"
118 struct Collector final : public loco::CanonicalNodeMutableVisitor<void>
119 {
120 // Let's find `FeatureDecode -- FeatureEncode` pattern (Feature -> Tensor -> Feature)
121 void visit(loco::FeatureEncode *encode_node) final
122 {
123 using namespace loco;
124
125 auto decode_node = dynamic_cast<loco::FeatureDecode *>(encode_node->input());
126 if (decode_node == nullptr)
127 {
128 return;
129 }
130 assert(decode_node->input() != nullptr);
131
132 auto encoder = encode_node->encoder();
133 assert(encoder != nullptr);
134
135 auto decoder = decode_node->decoder();
136 assert(decoder != nullptr);
137
138 // NOTE Works only for permuting codecs; a pair using any other codec is left untouched
139 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Feature> *>(decoder);
140 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Feature> *>(encoder);
141
142 if (perm_encoder == nullptr || perm_decoder == nullptr)
143 {
144 return;
145 }
146
// Only the equal-permutation case is handled here; there is no transpose fallback.
147 if (equal(perm_encoder->perm(), perm_decoder->perm()))
148 {
149 forwardCandidates.insert({encode_node, decode_node->input()});
150 }
151 }
152
153 // Let's find `FeatureEncode -- FeatureDecode` pattern (Tensor -> Feature -> Tensor)
154 void visit(loco::FeatureDecode *decode_node) final
155 {
156 using namespace loco;
157
158 auto encode_node = dynamic_cast<loco::FeatureEncode *>(decode_node->input());
159 if (encode_node == nullptr)
160 {
161 return;
162 }
163 assert(encode_node->input() != nullptr);
164
165 auto encoder = encode_node->encoder();
166 assert(encoder != nullptr);
167
168 auto decoder = decode_node->decoder();
169 assert(decoder != nullptr);
170
171 // NOTE Works only for permuting codecs; a pair using any other codec is left untouched
172 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Feature> *>(decoder);
173 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Feature> *>(encoder);
174
175 if (perm_encoder == nullptr || perm_decoder == nullptr)
176 {
177 return;
178 }
179
// Only the equal-permutation case is handled here; there is no transpose fallback.
180 if (equal(perm_encoder->perm(), perm_decoder->perm()))
181 {
182 forwardCandidates.insert({decode_node, encode_node->input()});
183 }
184 }
185
186 // Let's find `FilterEncode -- FilterDecode` pattern (Tensor -> Filter -> Tensor)
187 void visit(loco::FilterDecode *decode_node) final
188 {
189 using namespace loco;
190
191 auto encode_node = dynamic_cast<loco::FilterEncode *>(decode_node->input());
192 if (encode_node == nullptr)
193 {
194 return;
195 }
196 assert(encode_node->input() != nullptr);
197
198 auto encoder = encode_node->encoder();
199 assert(encoder != nullptr);
200
201 auto decoder = decode_node->decoder();
202 assert(decoder != nullptr);
203
204 // NOTE Works only for permuting codecs; a pair using any other codec is left untouched
205 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Filter> *>(decoder);
206 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Filter> *>(encoder);
207
208 if (perm_encoder == nullptr || perm_decoder == nullptr)
209 {
210 return;
211 }
212
213 if (equal(perm_encoder->perm(), perm_decoder->perm()))
214 {
215 forwardCandidates.insert({decode_node, encode_node->input()});
216 }
217 else
218 {
// Permutations differ: for each filter axis, store the encoder's tensor axis at the
// index given by the decoder's tensor axis, and queue the pair for a transpose.
219 std::vector<loco::TensorAxis> perm_vec;
220 perm_vec.resize(4);
221
222 auto enc_perm = perm_encoder->perm();
223 auto dec_perm = perm_decoder->perm();
224
225 for (const auto &axis :
226 {FilterAxis::Count, FilterAxis::Height, FilterAxis::Width, FilterAxis::Depth})
227 {
228 auto from = enc_perm->axis(axis);
229 auto to = dec_perm->axis(axis);
230 perm_vec[to] = from;
231 }
232
233 transposeCandidates.insert(
234 std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec));
235 }
236 }
237
238 // Let's find `BiasEncode -- BiasDecode` pattern (Tensor -> Bias -> Tensor); no codec is compared here
239 void visit(loco::BiasDecode *decode_node) final
240 {
241 if (auto encode_node = dynamic_cast<loco::BiasEncode *>(decode_node->input()))
242 {
243 assert(encode_node->input() != nullptr);
244 forwardCandidates.insert({decode_node, encode_node->input()});
245 }
246 }
247
248 // Let's find `DepthwiseFilterEncode -- DepthwiseFilterDecode` pattern (Tensor -> DepthwiseFilter -> Tensor)
249 void visit(loco::DepthwiseFilterDecode *decode_node) final
250 {
251 using namespace loco;
252
253 auto encode_node = dynamic_cast<loco::DepthwiseFilterEncode *>(decode_node->input());
254 if (encode_node == nullptr)
255 {
256 return;
257 }
258 assert(encode_node->input() != nullptr);
259
260 auto encoder = encode_node->encoder();
261 assert(encoder != nullptr);
262
263 auto decoder = decode_node->decoder();
264 assert(decoder != nullptr);
265
266 // NOTE Works only for permuting codecs; a pair using any other codec is left untouched
267 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::DepthwiseFilter> *>(decoder);
268 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::DepthwiseFilter> *>(encoder);
269
270 if (perm_encoder == nullptr || perm_decoder == nullptr)
271 {
272 return;
273 }
274
275 if (equal(perm_encoder->perm(), perm_decoder->perm()))
276 {
277 forwardCandidates.insert({decode_node, encode_node->input()});
278 }
279 else
280 {
// Permutations differ: for each depthwise-filter axis, store the encoder's tensor axis
// at the index given by the decoder's tensor axis, and queue the pair for a transpose.
281 std::vector<TensorAxis> perm_vec;
282 perm_vec.resize(4);
283
284 auto enc_perm = perm_encoder->perm();
285 auto dec_perm = perm_decoder->perm();
286
287 for (const auto &axis : {DepthwiseFilterAxis::Depth, DepthwiseFilterAxis::Height,
288 DepthwiseFilterAxis::Width, DepthwiseFilterAxis::Multiplier})
289 {
290 auto from = enc_perm->axis(axis);
291 auto to = dec_perm->axis(axis);
292 perm_vec[to] = from;
293 }
294
295 transposeCandidates.insert(
296 std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec));
297 }
298 }
299
300 // Let's find `MatrixDecode -- MatrixEncode` pattern (Matrix -> Tensor -> Matrix)
301 void visit(loco::MatrixEncode *encode_node) final
302 {
303 using namespace loco;
304
305 auto decode_node = dynamic_cast<loco::MatrixDecode *>(encode_node->input());
306 if (decode_node == nullptr)
307 {
308 return;
309 }
310 assert(decode_node->input() != nullptr);
311
312 auto encoder = encode_node->encoder();
313 assert(encoder != nullptr);
314
315 auto decoder = decode_node->decoder();
316 assert(decoder != nullptr);
317
318 // NOTE Works only for permuting codecs; a pair using any other codec is left untouched
319 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Matrix> *>(decoder);
320 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Matrix> *>(encoder);
321
322 if (perm_encoder == nullptr || perm_decoder == nullptr)
323 {
324 return;
325 }
326
// Only the equal-permutation case is handled here; there is no transpose fallback.
327 if (equal(perm_encoder->perm(), perm_decoder->perm()))
328 {
329 forwardCandidates.insert({encode_node, decode_node->input()});
330 }
331 }
332
333 // Let's find `MatrixEncode -- MatrixDecode` pattern (Tensor -> Matrix -> Tensor)
334 void visit(loco::MatrixDecode *decode_node) final
335 {
336 using namespace loco;
337
338 auto encode_node = dynamic_cast<loco::MatrixEncode *>(decode_node->input());
339 if (encode_node == nullptr)
340 {
341 return;
342 }
343 assert(encode_node->input() != nullptr);
344
345 auto encoder = encode_node->encoder();
346 assert(encoder != nullptr);
347
348 auto decoder = decode_node->decoder();
349 assert(decoder != nullptr);
350
351 // NOTE Works only for permuting codecs; a pair using any other codec is left untouched
352 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Matrix> *>(decoder);
353 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Matrix> *>(encoder);
354
355 if (perm_encoder == nullptr || perm_decoder == nullptr)
356 {
357 return;
358 }
359
360 if (equal(perm_encoder->perm(), perm_decoder->perm()))
361 {
362 forwardCandidates.insert({decode_node, encode_node->input()});
363 }
364 else
365 {
// Permutations differ: for each matrix axis, store the encoder's tensor axis at the
// index given by the decoder's tensor axis, and queue the pair for a transpose.
366 std::vector<loco::TensorAxis> perm_vec;
367 perm_vec.resize(2);
368
369 auto enc_perm = perm_encoder->perm();
370 auto dec_perm = perm_decoder->perm();
371
372 for (const auto &axis : {MatrixAxis::Height, MatrixAxis::Width})
373 {
374 auto from = enc_perm->axis(axis);
375 auto to = dec_perm->axis(axis);
376 perm_vec[to] = from;
377 }
378
379 transposeCandidates.insert(
380 std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec));
381 }
382 }
383
// No-op overload; presumably the fallback for node kinds without a dedicated visit() above.
384 void visit(loco::Node *) final { return; }
385
386 using SimplifyingInfo = std::pair<loco::Node * /* end node of the subgraph that will be replaced */,
387 loco::Node * /* input of the subgraph */>;
// Pairs whose encode/decode round trip can be replaced by a single Forward node.
388 std::set<SimplifyingInfo> forwardCandidates;
389
// Everything needed to replace an encode/decode pair with a TensorTranspose.
390 struct TransposeCtx
391 {
392 loco::Node *first_node; // starting node of the subgraph that will be replaced
393 loco::Node *last_node; // end node of the subgraph that will be replaced
394 loco::Node *input_node; // input of the subgraph
395 std::vector<loco::TensorAxis> perm_vec; // perm vector for the transpose
396
397 TransposeCtx(loco::Node *first, loco::Node *last, loco::Node *input,
398 std::vector<loco::TensorAxis> perm)
399 : first_node(first), last_node(last), input_node(input), perm_vec(perm)
400 { /* empty */
401 }
402 };
403
// NOTE(review): a std::set of unique_ptr is ordered by pointer value, so the iteration
// order below follows allocation addresses rather than graph order.
404 std::set<std::unique_ptr<TransposeCtx>> transposeCandidates;
405 };
406
407 Collector collector;
408
// Visit only nodes that belong to the canonical dialect among the active nodes of the graph outputs.
409 for (auto node : loco::active_nodes(loco::output_nodes(g)))
410 {
411 if (node->dialect() == loco::CanonicalDialect::get())
412 {
413 auto canonical_node = loco::must_cast<loco::CanonicalNode *>(node);
414 canonical_node->accept(&collector);
415 }
416 }
417
// Replace each forward candidate's end node with a new Forward node fed by the subgraph input.
// set_input_null is defined outside this listing; presumably it clears the replaced node's input.
418 for (auto p : collector.forwardCandidates)
419 {
420 auto forward_node = g->nodes()->create<loco::Forward>();
421 forward_node->input(p.second);
422 replace(p.first).with(forward_node);
423 set_input_null(p.first);
424 }
425
// Replace each transpose candidate's end node with a TensorTranspose carrying the collected perm.
// Here set_input_null is applied to the first (encode) node, not to the replaced end node.
426 for (auto &ctx : collector.transposeCandidates)
427 {
428 auto transpose_node = g->nodes()->create<loco::TensorTranspose>();
429 {
430 transpose_node->perm()->size(ctx->perm_vec.size());
431
432 for (loco::TensorAxis axis = 0; axis < ctx->perm_vec.size(); axis++)
433 transpose_node->perm()->axis(axis) = ctx->perm_vec[axis];
434 }
435
436 transpose_node->input(ctx->input_node);
437 replace(ctx->last_node).with(transpose_node);
438 set_input_null(ctx->first_node);
439 }
440
// Report a change whenever any candidate was collected.
441 return (collector.forwardCandidates.size() > 0 or collector.transposeCandidates.size() > 0);
442}
Create a "Tensor" from a "Bias".
Definition Nodes.h:743
Create a "Bias" from a "Tensor".
Definition Nodes.h:758
static Dialect * get(void)
Create a tensor from a depthwise filter.
Definition Nodes.h:475
Create a depthwise filter from a tensor.
Definition Nodes.h:456
Create a tensor from a feature map.
Definition Nodes.h:399
Create a feature map from a tensor.
Definition Nodes.h:380
Create a tensor from a filter.
Definition Nodes.h:437
Create a filter from a tensor.
Definition Nodes.h:418
Create a new value identical to its input.
Definition Nodes.h:146
Node * input(void) const
Definition Nodes.h:151
Create Tensor from Matrix.
Definition Nodes.h:1042
Create Matrix from Tensor.
Definition Nodes.h:1018
Logical unit of computation.
Definition Node.h:54
void with(Node *into) const
Definition Node.cpp:66
uint32_t size() const
Definition Nodes.h:1104
Permute an input.
Definition Nodes.h:1090
Perm * perm(void)
Definition Nodes.h:1114
uint32_t TensorAxis
Definition TensorAxis.h:25
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.

References loco::active_nodes(), loco::FeatureDecode::decoder(), loco::Node::dialect(), loco::FeatureEncode::encoder(), loco::CanonicalDialect::get(), loco::Forward::input(), loco::FeatureEncode::input(), loco::FeatureDecode::input(), loco::output_nodes(), loco::TensorTranspose::perm(), loco::replace(), loco::TensorTranspose::Perm::size(), and loco::Subst< SubstQualifier::Default >::with().

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following files: