ONE - On-device Neural Engine
nnc::CombineTransposes Class Reference

This pass combines sequential transposes into a single transpose and removes them entirely when the combined permutation is the identity.

#include <CombineTransposes.h>


Public Member Functions

PassData run (PassData data) override
 run compiler pass
 
std::string getName () override
 
Public Member Functions inherited from nnc::Pass
virtual void cleanup ()
 clean compiler pass data
 
virtual ~Pass ()=default
 

Detailed Description

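This pass combines each pair of sequential transposes into a single transpose. If the combined permutation is the identity, no new transpose is created: the input of the first transpose is connected directly to the consumers of the second. The first transpose is removed once nothing else uses its output.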
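To make the idea of "combining" two transposes concrete, here is a minimal sketch of composing two axis orders and testing the result for identity. The helper names `composeAxisOrders` and `isIdentityOrder` are hypothetical and are not the nnc implementation. The sketch also assumes the convention that output axis `i` of a transpose is input axis `order[i]`; the library may use a different convention.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using AxisOrder = std::vector<std::size_t>;

// Hypothetical helper, not the nnc implementation.
// Assumed convention: output axis i of a transpose is input axis order[i].
// Applying `first` and then `second` therefore maps output axis i to
// input axis first[second[i]].
AxisOrder composeAxisOrders(const AxisOrder &first, const AxisOrder &second)
{
  assert(first.size() == second.size());
  AxisOrder combined(second.size());
  for (std::size_t i = 0; i < second.size(); ++i)
    combined[i] = first[second[i]];
  return combined;
}

// True if the order leaves every axis in place.
bool isIdentityOrder(const AxisOrder &order)
{
  for (std::size_t i = 0; i < order.size(); ++i)
    if (order[i] != i)
      return false;
  return true;
}
```

Under this assumed convention, `{0, 2, 3, 1}` followed by `{0, 3, 1, 2}` composes to the identity, so such a pair could be dropped, whereas `{1, 0, 2}` followed by `{0, 2, 1}` composes to `{1, 2, 0}` and would be kept as one transpose.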

Definition at line 30 of file CombineTransposes.h.

Member Function Documentation

◆ getName()

std::string nnc::CombineTransposes::getName ( )
inline override virtual

Reimplemented from nnc::Pass.

Definition at line 35 of file CombineTransposes.h.

{ return "opt_combine_transposes"; };

◆ run()

nnc::PassData nnc::CombineTransposes::run ( PassData  data)
override virtual

run compiler pass

Parameters
data - data that the pass takes as input
Returns
data that can be passed to the next pass
Exceptions
PassException object if errors occurred

Implements nnc::Pass.

Definition at line 52 of file CombineTransposes.cpp.

{
  auto g = static_cast<Graph *>(data);
  assert(g);
  GraphPatternMatcher matcher(g);
  auto is_tr = [](const Operation *op1) { return op1->getType() == Operation::Type::transpose; };
  std::vector<std::pair<Operation *, Operation *>> matches = matcher.matchEdge(is_tr, is_tr);
  std::unordered_set<Operation *> deleted_nodes;
  while (!matches.empty())
  {
    for (std::pair<Operation *, Operation *> match : matches)
    {
      if (deleted_nodes.find(match.first) != deleted_nodes.end())
      {
        break;
      };
      auto *top_transpose = dynamic_cast<mir::ops::TransposeOp *>(match.first);
      if (deleted_nodes.find(match.second) != deleted_nodes.end())
      {
        break;
      };
      auto *bottom_transpose = dynamic_cast<mir::ops::TransposeOp *>(match.second);
      auto combined_axis_order =
        combineAxisOrders(top_transpose->getAxisOrder(), bottom_transpose->getAxisOrder());

      if (!isIdentityTranspose(combined_axis_order))
      {
        auto new_tr_op =
          g->create<mir::ops::TransposeOp>(top_transpose->getInput(0), combined_axis_order);

        g->replaceNode(bottom_transpose, new_tr_op);
      }
      else
      {
        // Connect top input to all outputs of bottom
        Operation *top = top_transpose->getInput(0)->getNode();
        g->replaceNode(bottom_transpose, top);
      }
      deleted_nodes.emplace(bottom_transpose);
      if (top_transpose->getOutput(0)->getUses().empty())
      {
        g->removeNode(top_transpose);
        deleted_nodes.emplace(top_transpose);
      }
    }
    matches = matcher.matchEdge(is_tr, is_tr);
  };
  return g;
}
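The loop above re-runs the matcher until no transpose-to-transpose edge is left, so chains longer than two collapse over several rounds. A simplified sketch of that fixed-point rewrite on a toy linear chain follows. `ToyOp`, `compose`, `isIdentity`, and `fuseTransposes` are hypothetical names introduced only for illustration; the real pass works on a `Graph` with arbitrary fan-out, which this sketch does not model. It again assumes that output axis `i` of a transpose is input axis `order[i]`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using AxisOrder = std::vector<std::size_t>;

// Hypothetical toy node: a transpose with an axis order, or any other op.
struct ToyOp
{
  std::string kind; // "transpose" or anything else
  AxisOrder order;  // only meaningful for transposes
};

// Assumed convention: output axis i of a transpose is input axis order[i].
AxisOrder compose(const AxisOrder &first, const AxisOrder &second)
{
  assert(first.size() == second.size());
  AxisOrder combined(second.size());
  for (std::size_t i = 0; i < second.size(); ++i)
    combined[i] = first[second[i]];
  return combined;
}

bool isIdentity(const AxisOrder &order)
{
  for (std::size_t i = 0; i < order.size(); ++i)
    if (order[i] != i)
      return false;
  return true;
}

// Fuse adjacent transposes in a linear chain until no adjacent pair remains.
std::vector<ToyOp> fuseTransposes(std::vector<ToyOp> chain)
{
  bool changed = true;
  while (changed)
  {
    changed = false;
    for (std::size_t i = 0; i + 1 < chain.size(); ++i)
    {
      if (chain[i].kind != "transpose" || chain[i + 1].kind != "transpose")
        continue;
      AxisOrder combined = compose(chain[i].order, chain[i + 1].order);
      auto pos = chain.begin() + static_cast<std::ptrdiff_t>(i);
      if (isIdentity(combined))
      {
        chain.erase(pos, pos + 2); // the pair cancels out: drop both
      }
      else
      {
        chain[i].order = combined; // one transpose replaces the pair
        chain.erase(pos + 1);
      }
      changed = true;
      break; // positions shifted, so rescan from the start
    }
  }
  return chain;
}
```

Rescanning after every rewrite mirrors the way the pass refreshes `matches` instead of reusing stale pairs; it is simple rather than efficient, which is usually acceptable for short transpose chains.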

References nnc::combineAxisOrders(), mir::Operation::getInput(), mir::Operation::Output::getNode(), and mir::GraphPatternMatcher::matchEdge().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: