ONE - On-device Neural Engine
Loading...
Searching...
No Matches
CombineTransposes.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "mir/ops/TransposeOp.h"
19#include "mir/Graph.h"
21#include <algorithm>
22
23namespace nnc
24{
25
26using namespace mir;
27
/**
 * @brief Computes the axis order of one transpose equivalent to a transpose with
 *        @p order1 followed by a transpose with @p order2.
 *
 * Assumes the mir::ops::TransposeOp convention that output dimension `i` is taken
 * from input dimension `axis_order[i]` (same as numpy `transpose(axes=...)`).
 * Under that convention, the second transpose reads its dimension `i` from the
 * intermediate dimension `order2[i]`. That dimension in turn comes from input
 * dimension `order1[order2[i]]`.
 *
 * @param order1 Axis order of the first (upstream) transpose.
 * @param order2 Axis order of the second (downstream) transpose; same rank as @p order1.
 * @return Axis order `r` with `r[i] == order1[order2[i]]`.
 */
std::vector<size_t> combineAxisOrders(const std::vector<std::size_t> &order1,
                                      const std::vector<std::size_t> &order2)
{
  assert(order1.size() == order2.size());
  std::vector<size_t> res(order2.size());
  for (size_t i = 0; i < order2.size(); i++)
  {
    assert(order2[i] < order1.size());
    res[i] = order1[order2[i]];
  }
  return res;
}
39
/// Checks whether @p axis_order is the identity permutation, i.e.
/// `axis_order[i] == i` for every position `i`. An empty order counts as identity.
static bool isIdentityTranspose(const std::vector<size_t> &axis_order)
{
  std::size_t position = 0;
  for (const std::size_t axis : axis_order)
  {
    if (axis != position)
      return false;
    ++position;
  }
  return true;
}
51
53{
54 auto g = static_cast<Graph *>(data);
55 assert(g);
56 GraphPatternMatcher matcher(g);
57 auto is_tr = [](const Operation *op1) { return op1->getType() == Operation::Type::transpose; };
58 std::vector<std::pair<Operation *, Operation *>> matches = matcher.matchEdge(is_tr, is_tr);
59 std::unordered_set<Operation *> deleted_nodes;
60 while (!matches.empty())
61 {
62 for (std::pair<Operation *, Operation *> match : matches)
63 {
64 if (deleted_nodes.find(match.first) != deleted_nodes.end())
65 {
66 break;
67 };
68 auto *top_transpose = dynamic_cast<mir::ops::TransposeOp *>(match.first);
69 if (deleted_nodes.find(match.second) != deleted_nodes.end())
70 {
71 break;
72 };
73 auto *bottom_transpose = dynamic_cast<mir::ops::TransposeOp *>(match.second);
74 auto combined_axis_order =
75 combineAxisOrders(top_transpose->getAxisOrder(), bottom_transpose->getAxisOrder());
76
77 if (!isIdentityTranspose(combined_axis_order))
78 {
79 auto new_tr_op =
80 g->create<mir::ops::TransposeOp>(top_transpose->getInput(0), combined_axis_order);
81
82 g->replaceNode(bottom_transpose, new_tr_op);
83 }
84 else
85 {
86 // Connect top input to all outputs of bottom
87 Operation *top = top_transpose->getInput(0)->getNode();
88 g->replaceNode(bottom_transpose, top);
89 }
90 deleted_nodes.emplace(bottom_transpose);
91 if (top_transpose->getOutput(0)->getUses().empty())
92 {
93 g->removeNode(top_transpose);
94 deleted_nodes.emplace(top_transpose);
95 }
96 }
97 matches = matcher.matchEdge(is_tr, is_tr);
98 };
99 return g;
100}
101
102} // namespace nnc
std::vector< std::pair< Operation *, Operation * > > matchEdge(Predicate p1, Predicate p2)
Match an edge using two predicates, one for each end of the edge.
Operation * getNode()
Returns the node this is an output of.
Definition Operation.h:72
Output * getInput(std::size_t index)
Definition Operation.h:137
Tensor transpose operation.
Definition TransposeOp.h:34
PassData run(PassData data) override
run compiler pass
Class that encapsulates the value returned and taken by a pass.
Definition PassData.h:30
std::vector< size_t > combineAxisOrders(const std::vector< std::size_t > &order1, const std::vector< std::size_t > &order2)