39using namespace opt_util;
41using OpType = Operation::Type;
42using Edge = pair<Operation *, Operation *>;
50 assert(op->getType() == OpType::add || op->getType() == OpType::mul ||
51 op->getType() == OpType::conv2D);
56vector<Edge> findSuccessiveOpsWithConstWeights(
Graph *g, OpType first_op_type,
57 OpType second_op_type)
60 unordered_set<Operation *> matched_nodes;
61 for (
auto *first_op :
g->getNodes())
63 if (first_op->getType() == first_op_type && getSecondInputAsConst(first_op))
65 for (
auto &out : first_op->getOutputs())
67 for (Operation::Use use : out.getUses())
70 if (second_op->getType() == second_op_type && getSecondInputAsConst(second_op))
76 if (matched_nodes.find(first_op) == matched_nodes.end() &&
77 matched_nodes.find(second_op) == matched_nodes.end())
79 matched_nodes.emplace(first_op);
80 matched_nodes.emplace(second_op);
81 matches.emplace_back(first_op, second_op);
98 const auto &const1_val = const1_op->
getValue();
99 const auto &const2_val = const2_op->
getValue();
100 assert(const1_val.getShape().rank() >= const2_val.getShape().rank());
101 assert(const2_val.getShape().rank() == 1);
102 assert(const1_val.getShape().dim(0) == const2_val.getShape().dim(0));
105 TensorVariant new_const_val(DataType::FLOAT32, const1_val.getShape());
109 ShapeRange const1_range(const1_val.getShape());
110 for (
auto &idx : const1_range)
112 float operand1 = const1_accessor.at(idx);
117 float operand2 = const2_accessor.at(
Index{idx.
at(0)});
121 new_const_accessor.at(idx) = operand1 * operand2;
124 new_const_accessor.at(idx) = operand1 + operand2;
127 assert(
false &&
"only 'mul' and 'add' constants merge types supported");
158bool fuseSuccessiveOps(
Graph *g)
161 vector<Edge> successive_ops;
162 auto mul_mul_vec = findSuccessiveOpsWithConstWeights(g, OpType::mul, OpType::mul);
163 successive_ops.insert(successive_ops.end(), mul_mul_vec.begin(), mul_mul_vec.end());
164 auto add_add_vec = findSuccessiveOpsWithConstWeights(g, OpType::add, OpType::add);
165 successive_ops.insert(successive_ops.end(), add_add_vec.begin(), add_add_vec.end());
166 auto conv_mul_vec = findSuccessiveOpsWithConstWeights(g, OpType::conv2D, OpType::mul);
167 successive_ops.insert(successive_ops.end(), conv_mul_vec.begin(), conv_mul_vec.end());
169 for (
auto &edge : successive_ops)
171 auto const1_op = getSecondInputAsConst(edge.first);
172 auto const2_op = getSecondInputAsConst(edge.second);
173 assert(const1_op && const2_op);
176 auto new_const_op = mergeConstantOps(g, const1_op, const2_op, edge.second->getType());
177 auto first_op_input = edge.first->getInput(0);
178 auto new_op =
g->copyOpWithInputs(edge.first, {first_op_input, new_const_op->getOutput(0)});
181 g->replaceNode(edge.second, new_op);
188 return !successive_ops.empty();
203bool sinkAddThroughMul(
Graph *g)
205 auto add_mul_edges = findSuccessiveOpsWithConstWeights(g, OpType::add, OpType::mul);
207 for (
auto &edge : add_mul_edges)
209 auto old_add_op = edge.first;
210 auto old_mul_op = edge.second;
211 auto old_add_const_op = getSecondInputAsConst(old_add_op);
212 auto ols_mul_const_op = getSecondInputAsConst(old_mul_op);
213 assert(old_add_const_op && ols_mul_const_op);
216 auto old_add_input = old_add_op->getInput(0);
218 g->copyOpWithInputs(old_mul_op, {old_add_input, ols_mul_const_op->getOutput(0)});
219 auto new_add_const_op = mergeConstantOps(g, old_add_const_op, ols_mul_const_op, OpType::mul);
221 g->copyOpWithInputs(old_add_op, {new_mul_op->getOutput(0), new_add_const_op->getOutput(0)});
224 g->replaceNode(old_mul_op, new_add_op);
230 return !add_mul_edges.empty();
237 auto g =
static_cast<Graph *
>(data);
239 bool graph_changed =
true;
240 while (graph_changed)
242 graph_changed =
false;
243 graph_changed |= fuseSuccessiveOps(g);
244 graph_changed |= sinkAddThroughMul(g);
// Cross-references (extraction residue from the original source browser,
// kept here as documentation of the external APIs this pass relies on):
//  - int32_t &Index::at(int32_t axis)      — returns the position on the given axis
//  - const TensorVariant &ConstantOp::getValue() const
//  - PassData Pass::run(PassData data) override — runs the compiler pass;
//    PassData is the class that encapsulates the value returned and taken by a pass
//  - void removeNodeIfUnused(mir::Graph *g, mir::Operation *op)