ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnc::SinkTranspose Class Reference

This pass sinks transposes below Relu and Concat nodes (in that order). `concat(relu(tr(x)), relu(tr(y))) -> tr(concat'(relu(x), relu(y)))`. More...

#include <SinkTranspose.h>

Collaboration diagram for nnc::SinkTranspose:

Public Member Functions

PassData run (PassData data) override
 run compiler pass
 
std::string getName () override
 
- Public Member Functions inherited from nnc::Pass
virtual void cleanup ()
 clean compiler pass data
 
virtual ~Pass ()=default
 

Detailed Description

This pass sinks transposes below Relu and Concat nodes (in that order). `concat(relu(tr(x)), relu(tr(y))) -> tr(concat'(relu(x), relu(y)))`.

Definition at line 30 of file SinkTranspose.h.

Member Function Documentation

◆ getName()

std::string nnc::SinkTranspose::getName ( )
inline override virtual

Reimplemented from nnc::Pass.

Definition at line 35 of file SinkTranspose.h.

35{ return "SinkTranspose"; };

◆ run()

PassData nnc::SinkTranspose::run ( PassData  data)
override virtual

run compiler pass

Parameters
data - data that the pass takes as input
Returns
data that can be passed to the next pass
Exceptions
PassException object if errors occurred

Implements nnc::Pass.

Definition at line 34 of file SinkTranspose.cpp.

35{
36 auto g = static_cast<Graph *>(data);
37 assert(g); // NOLINT
38 GraphPatternMatcher matcher(g);
39 auto is_tr = [](const Operation *op1) { return op1->getType() == Operation::Type::transpose; };
40 auto is_relu = [](const Operation *op2) { return op2->getType() == Operation::Type::ReLU; };
41 auto is_concat = [](const Operation *op2) { return op2->getType() == Operation::Type::concat; };
42 std::vector<std::pair<Operation *, Operation *>> matches;
43
44 // sink transpose below ReLU
45 matches = matcher.matchEdge(is_tr, is_relu);
46 for (auto pair : matches)
47 {
48 swapAdjacent(g, pair.first, pair.second);
49 }
50
51 // sink transpose through Concat
52 auto v_matches = matcher.matchUpBush(is_tr, is_concat);
53 for (const auto &pair : v_matches)
54 {
55 std::vector<Operation *> trs = pair.first;
56 auto *concat = dynamic_cast<ops::ConcatOp *>(pair.second);
57 auto axis_order = dynamic_cast<ops::TransposeOp *>(trs[0])->getAxisOrder();
58 if (std::all_of(trs.begin(), trs.end(), [&axis_order](Operation *tr) {
59 return dynamic_cast<ops::TransposeOp *>(tr)->getAxisOrder() == axis_order;
60 }))
61 {
62 std::vector<Operation::Output *> prev_trans;
63 prev_trans.reserve(trs.size());
64 for (auto transpose : trs)
65 {
66 prev_trans.emplace_back(transpose->getInput(0));
67 }
68 auto new_concat = g->create<ops::ConcatOp>(prev_trans, axis_order[concat->getAxis()]);
69 auto new_transpose = g->create<ops::TransposeOp>(new_concat->getOutput(0), axis_order);
70 // removes old concat
71 g->replaceNode(concat, new_transpose);
72 for (auto tr : trs)
73 {
74 removeNodeIfUnused(g, tr);
75 }
76 }
77 }
78
79 return g;
80}
Description of tensor concatenation operation.
Definition ConcatOp.h:31
Tensor transpose operation.
Definition TransposeOp.h:34
void concat(std::ostream &os, const std::string &sep, It beg, It end)
Definition String.h:31
void swapAdjacent(mir::Graph *g, mir::Operation *top, mir::Operation *bottom)
Swap adjacent nodes in Graph. Creates new nodes and replaces the old ones with new.
void removeNodeIfUnused(mir::Graph *g, mir::Operation *op)

References mir::GraphPatternMatcher::matchEdge(), mir::GraphPatternMatcher::matchUpBush(), nnc::opt_util::removeNodeIfUnused(), and nnc::opt_util::swapAdjacent().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: