ONE - On-device Neural Engine
FuseArithmeticOps.cpp
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Pass declaration and shared optimization helpers (header paths assumed from the nnc layout)
#include "passes/optimizations/FuseArithmeticOps.h"
#include "passes/optimizations/OptimizationUtils.h"

#include "mir/ops/AddOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/Conv2DOp.h"
#include "mir/ops/MulOp.h"
#include "mir/Graph.h"
#include "mir/Tensor.h"
#include "mir/Index.h"
#include "mir/TensorVariant.h"
#include "mir/ShapeRange.h"

#include <algorithm>
#include <unordered_set>
#include <utility>
#include <vector>

namespace nnc
{

namespace
{

using namespace mir;
using namespace std;
using namespace opt_util;

using OpType = Operation::Type;
using Edge = pair<Operation *, Operation *>;

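// Returns the second input of 'op' as a ConstantOp if it is one, otherwise nullptr.
// Only Add, Mul and Conv2D operations are expected here.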
ops::ConstantOp *getSecondInputAsConst(Operation *op)
{
  assert(op->getType() == OpType::add || op->getType() == OpType::mul ||
         op->getType() == OpType::conv2D);
  return dynamic_cast<ops::ConstantOp *>(op->getInput(1)->getNode());
}

// Finds successive operations of the given types, each with a ConstantOp as its second input
vector<Edge> findSuccessiveOpsWithConstWeights(Graph *g, OpType first_op_type,
                                               OpType second_op_type)
{
  vector<Edge> matches;
  unordered_set<Operation *> matched_nodes;
  for (auto *first_op : g->getNodes())
  {
    if (first_op->getType() == first_op_type && getSecondInputAsConst(first_op))
    {
      for (auto &out : first_op->getOutputs())
      {
        for (Operation::Use use : out.getUses())
        {
          Operation *second_op = use.getNode();
          if (second_op->getType() == second_op_type && getSecondInputAsConst(second_op))
          {
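            // Record each operation in at most one match per pass; overlapping chains
            // are picked up again on the next iteration of the pass loop.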
            if (matched_nodes.find(first_op) == matched_nodes.end() &&
                matched_nodes.find(second_op) == matched_nodes.end())
            {
              matched_nodes.emplace(first_op);
              matched_nodes.emplace(second_op);
              matches.emplace_back(first_op, second_op);
            }
          }
        }
      }
    }
  }
  return matches;
}

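// Element-wise merges two constants into a new ConstantOp. The second constant must be
// rank-1 and is broadcast along dimension 0 of the first one; 'merge_type' selects
// whether the values are multiplied or added.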
Operation *mergeConstantOps(Graph *g, const ops::ConstantOp *const1_op,
                            const ops::ConstantOp *const2_op, OpType merge_type)
{
  const auto &const1_val = const1_op->getValue();
  const auto &const2_val = const2_op->getValue();
  assert(const1_val.getShape().rank() >= const2_val.getShape().rank());
  assert(const2_val.getShape().rank() == 1);
  assert(const1_val.getShape().dim(0) == const2_val.getShape().dim(0));

  // Create and fill TensorVariant for new ConstantOp
  TensorVariant new_const_val(DataType::FLOAT32, const1_val.getShape());
  Tensor<float> const1_accessor(const1_val);
  Tensor<float> const2_accessor(const2_val);
  Tensor<float> new_const_accessor(new_const_val);
  ShapeRange const1_range(const1_val.getShape());
  for (auto &idx : const1_range)
  {
    float operand1 = const1_accessor.at(idx);
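    // The rank-1 constant is indexed only by the first coordinate of 'idx',
    // i.e. it is broadcast across all remaining dimensions.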
    float operand2 = const2_accessor.at(Index{idx.at(0)});
    switch (merge_type)
    {
      case OpType::mul:
        new_const_accessor.at(idx) = operand1 * operand2;
        break;
      case OpType::add:
        new_const_accessor.at(idx) = operand1 + operand2;
        break;
      default:
        assert(false && "only 'mul' and 'add' constant merge types are supported");
    }
  }

  return g->create<ops::ConstantOp>(new_const_val);
}


// TODO: support 'DepthwiseConv'->'Mul'
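// Fuses a chain of two operations that both take a constant second input into a single
// operation with a pre-computed constant. Handled patterns: Mul->Mul, Add->Add and
// Conv2D->Mul. Returns true if the graph was changed.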
bool fuseSuccessiveOps(Graph *g)
{
  // Find all successive ops
  vector<Edge> successive_ops;
  auto mul_mul_vec = findSuccessiveOpsWithConstWeights(g, OpType::mul, OpType::mul);
  successive_ops.insert(successive_ops.end(), mul_mul_vec.begin(), mul_mul_vec.end());
  auto add_add_vec = findSuccessiveOpsWithConstWeights(g, OpType::add, OpType::add);
  successive_ops.insert(successive_ops.end(), add_add_vec.begin(), add_add_vec.end());
  auto conv_mul_vec = findSuccessiveOpsWithConstWeights(g, OpType::conv2D, OpType::mul);
  successive_ops.insert(successive_ops.end(), conv_mul_vec.begin(), conv_mul_vec.end());

  for (auto &edge : successive_ops)
  {
    auto const1_op = getSecondInputAsConst(edge.first);
    auto const2_op = getSecondInputAsConst(edge.second);
    assert(const1_op && const2_op);

    // Create a new constant operation and copy the first successive operation
    auto new_const_op = mergeConstantOps(g, const1_op, const2_op, edge.second->getType());
    auto first_op_input = edge.first->getInput(0);
    auto new_op = g->copyOpWithInputs(edge.first, {first_op_input, new_const_op->getOutput(0)});

    // Replace the second successive operation with the new one and remove the old nodes
    g->replaceNode(edge.second, new_op);
    removeNodeIfUnused(g, edge.first);
    removeNodeIfUnused(g, const1_op);
    removeNodeIfUnused(g, const2_op);
  }

  // If there were no successive operations to fuse, the graph was not changed
  return !successive_ops.empty();
}

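// Rewrites an Add->Mul chain (both with constant second inputs) into Mul->Add with an
// updated additive constant, so that the Mul can later be fused into a preceding
// operation (e.g. Conv2D) by fuseSuccessiveOps. Returns true if the graph was changed.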
bool sinkAddThroughMul(Graph *g)
{
  auto add_mul_edges = findSuccessiveOpsWithConstWeights(g, OpType::add, OpType::mul);

  for (auto &edge : add_mul_edges)
  {
    auto old_add_op = edge.first;
    auto old_mul_op = edge.second;
    auto old_add_const_op = getSecondInputAsConst(old_add_op);
    auto old_mul_const_op = getSecondInputAsConst(old_mul_op);
    assert(old_add_const_op && old_mul_const_op);

    // Create new operations
    auto old_add_input = old_add_op->getInput(0);
    auto new_mul_op =
        g->copyOpWithInputs(old_mul_op, {old_add_input, old_mul_const_op->getOutput(0)});
    auto new_add_const_op = mergeConstantOps(g, old_add_const_op, old_mul_const_op, OpType::mul);
    auto new_add_op =
        g->copyOpWithInputs(old_add_op, {new_mul_op->getOutput(0), new_add_const_op->getOutput(0)});

    // Replace the old Mul with the new Add and remove the old nodes
    g->replaceNode(old_mul_op, new_add_op);
    removeNodeIfUnused(g, old_add_op);
    removeNodeIfUnused(g, old_add_const_op);
  }

  // If there were no Add->Mul edges, the graph was not changed
  return !add_mul_edges.empty();
}

} // unnamed namespace

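// Entry point of the pass: repeatedly fuses successive arithmetic operations and sinks
// Add through Mul until the graph stops changing.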
PassData FuseArithmeticOps::run(PassData data)
{
  auto g = static_cast<Graph *>(data);

  bool graph_changed = true;
  while (graph_changed)
  {
    graph_changed = false;
    graph_changed |= fuseSuccessiveOps(g);
    graph_changed |= sinkAddThroughMul(g);
  }

  return g;
}

} // namespace nnc