29 const std::vector<std::size_t> &order2)
31 assert(order1.size() == order2.size());
32 std::vector<size_t> res(order1.size());
33 for (
size_t i = 0; i < order1.size(); i++)
35 res[order2[order1[i]]] = i;
54 auto g =
static_cast<Graph *
>(data);
57 auto is_tr = [](
const Operation *op1) {
return op1->getType() == Operation::Type::transpose; };
58 std::vector<std::pair<Operation *, Operation *>> matches = matcher.
matchEdge(is_tr, is_tr);
59 std::unordered_set<Operation *> deleted_nodes;
60 while (!matches.empty())
62 for (std::pair<Operation *, Operation *> match : matches)
64 if (deleted_nodes.find(match.first) != deleted_nodes.end())
69 if (deleted_nodes.find(match.second) != deleted_nodes.end())
74 auto combined_axis_order =
75 combineAxisOrders(top_transpose->getAxisOrder(), bottom_transpose->getAxisOrder());
77 if (!isIdentityTranspose(combined_axis_order))
82 g->replaceNode(bottom_transpose, new_tr_op);
88 g->replaceNode(bottom_transpose, top);
90 deleted_nodes.emplace(bottom_transpose);
91 if (top_transpose->getOutput(0)->getUses().empty())
93 g->removeNode(top_transpose);
94 deleted_nodes.emplace(top_transpose);
97 matches = matcher.
matchEdge(is_tr, is_tr);