ONE - On-device Neural Engine
Loading...
Searching...
No Matches
CombineTransposes.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "mir/ops/TransposeOp.h"
19#include "mir/ops/ReluOp.h"
20#include "mir/ops/OutputOp.h"
21#include "Util.h"
22#include <gtest/gtest.h>
23
24using namespace std;
25using namespace nnc;
26using namespace mir;
27
28namespace
29{
30
31TEST(OptPass, eliminateTransposesLinear)
32{
34 /* Create graph:
35 * [input]
36 * ||
37 * [Transpose 1]
38 * ||
39 * [Transpose 2]
40 * ||
41 * [relu]
42 */
43 mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3}};
44 Operation *input = g.create<ops::InputOp>(input_type);
45 Operation *tr1 = g.create<ops::TransposeOp>(input->getOutput(0), vector<size_t>{1, 0, 2});
46 Operation *tr15 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{1, 0, 2});
47 Operation *tr2 = g.create<ops::TransposeOp>(tr15->getOutput(0), vector<size_t>{1, 0, 2});
48 Operation *relu = g.create<ops::ReluOp>(tr2->getOutput(0));
49
50 // Check that layout is desired
51 std::stringstream ss;
52 DumpVisitor d(ss);
54 pass.run(&g);
55 g.accept(&d);
56 // Assert only 1 transpose remains
57 ASSERT_EQ("i_0.t_1.r_4.", ss.str());
58}
59
60TEST(OptPass, combineTransposesLinear)
61{
63 /* Create graph:
64 * [input]
65 * ||
66 * [Transpose 1]
67 * ||
68 * [Transpose 2]
69 * ||
70 * [relu]
71 */
72 mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3}};
73 Operation *input = g.create<ops::InputOp>(input_type);
74 Operation *tr1 = g.create<ops::TransposeOp>(input->getOutput(0), vector<size_t>{1, 0, 2});
75 Operation *tr2 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{0, 2, 1});
76 Operation *relu = g.create<ops::ReluOp>(tr2->getOutput(0));
77
78 std::stringstream ss;
79 DumpVisitor d(ss);
81 pass.run(&g);
82 g.accept(&d);
83
84 // Assert transposes are combined
85 ASSERT_EQ("i_0.t_4.r_3.", ss.str());
86 Operation::Use use = g.getInputs()[0]->getOutput(0)->getUses()[0];
87 auto ax_ord_actual = dynamic_cast<ops::TransposeOp *>(use.getNode())->getAxisOrder();
88 auto ax_ord_true = vector<size_t>{1, 2, 0};
89 ASSERT_TRUE(ax_ord_actual == ax_ord_true);
90}
91
92TEST(OptPass, combineTransposesBush)
93{
95 /* Create graph:
96 * [input]
97 * ||
98 * [Transpose 1]
99 * // \\
100 *[Transpose 2] [Transpose 3]
101 * \\ //
102 * [Add]
103 */
104 mir::TensorType input_type{mir::DataType::FLOAT32, Shape{1, 2, 3, 2}};
105 Operation *input = g.create<ops::InputOp>(input_type);
106 Operation *tr1 = g.create<ops::TransposeOp>(input->getOutput(0), vector<size_t>{1, 0, 2, 3});
107 Operation *tr2 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
108 Operation *tr3 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
109 Operation *elw = g.create<ops::AddOp>(tr2->getOutput(0), tr3->getOutput(0));
110 std::stringstream ss;
111 DumpVisitor d(ss);
113 pass.run(&g);
114 g.accept(&d);
115 ASSERT_EQ("i_0.b_4.", ss.str());
116 ASSERT_EQ(elw->getInput(0)->getNode()->getType(), mir::Operation::Type::input);
117 ASSERT_EQ(elw->getInput(1)->getNode()->getType(), mir::Operation::Type::input);
118}
119
120TEST(OptPass, combineTransposesOpOrder)
121{
123 /* Create graph:
124 * [input] [input2]
125 * || ||
126 * [Transpose 0] [Transpose1]
127 * || ||
128 * [Transpose 2] [Transpose 3]
129 * \\ //
130 * [Add]
131 */
132 mir::TensorType input_type{mir::DataType::FLOAT32, {1, 2, 3}};
133 Operation *in1 = g.create<ops::InputOp>(input_type);
134 Operation *in2 = g.create<ops::InputOp>(input_type);
135 Operation *tr0 = g.create<ops::TransposeOp>(in1->getOutput(0), vector<size_t>{1, 0, 2});
136 Operation *tr1 = g.create<ops::TransposeOp>(in2->getOutput(0), vector<size_t>{2, 1, 0});
137 Operation *tr2 = g.create<ops::TransposeOp>(tr0->getOutput(0), vector<size_t>{1, 0, 2});
138 Operation *tr3 = g.create<ops::TransposeOp>(tr1->getOutput(0), vector<size_t>{2, 1, 0});
139 Operation *elw = g.create<ops::AddOp>(tr2->getOutput(0), tr3->getOutput(0));
140 g.create<ops::OutputOp>(elw->getOutput(0));
141 int n1 = in1->getId();
142 int n2 = in2->getId();
144 pass.run(&g);
145 ASSERT_EQ(g.getOutputs()[0]->getInput(0)->getNode()->getType(), mir::Operation::Type::add);
146 // Order is preserved
147 ASSERT_EQ(n1, elw->getInput(0)->getNode()->getId());
148 ASSERT_EQ(n2, elw->getInput(1)->getNode()->getId());
149}
150} // unnamed namespace
Tensor transpose operation.
Definition TransposeOp.h:34
This pass combines sequential transposes and removes identity transposes if the combination results in an identity permutation.
PassData run(PassData data) override
run compiler pass
TEST(Shape, Base)
Definition Index.cpp:24
Definition Shape.h:28