ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseRsqrtPass.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "helpers/NodeFiller.h"
19
20#include <luci/IR/CircleNodes.h>
22
23#include <cmath>
24#include <cassert>
25
26namespace
27{
28
// Approximate float equality: true when |a - b| is below an absolute tolerance of 1e-5
bool same(float a, float b)
{
  const float diff = a - b;
  return std::fabs(diff) < 1e-5;
}
50
51class RsqrtPattern
52{
53public:
54 RsqrtPattern(luci::CircleDiv *candidate)
55 {
56 assert(candidate); // FIX_CALLER_UNLESS
57 _div = candidate;
58 }
59
60#define CHECK_OR_FALSE(condition) \
61 if (not(condition)) \
62 return false;
63
64public:
65 bool matched()
66 {
67 // Check pattern
68 CHECK_OR_FALSE(luci::fill(&_div_const, &_sqrt).with_args_of(_div));
69 _ifm = loco::must_cast<luci::CircleNode *>(_sqrt->x());
70
71 CHECK_OR_FALSE(_div->fusedActivationFunction() == luci::FusedActFunc::NONE);
72
73 // Check div_const = 1
74 switch (_div->dtype())
75 {
76 case loco::DataType::S16:
77 CHECK_OR_FALSE(_div_const->quantparam() != nullptr);
78 CHECK_OR_FALSE(_div_const->quantparam()->scale.size() == 1);
79 CHECK_OR_FALSE(_div_const->quantparam()->zerop.size() == 1);
80 CHECK_OR_FALSE(_div_const->quantparam()->zerop.at(0) == 0);
81 CHECK_OR_FALSE(_div_const->size<loco::DataType::S16>() == 1);
82 CHECK_OR_FALSE(same(1.0, _div_const->at<loco::DataType::S16>(0) *
83 _div_const->quantparam()->scale.at(0)));
84 break;
85 // TODO Support more dtypes
86 default:
87 return false;
88 }
89
90 return true;
91 }
92#undef CHECK_OR_FALSE
93
94public:
95 luci::CircleNode *_ifm = nullptr;
96 luci::CircleSqrt *_sqrt = nullptr;
97 luci::CircleDiv *_div = nullptr;
98 luci::CircleConst *_div_const = nullptr;
99};
100
101class FuseRsqrt final
102{
103public:
104 FuseRsqrt(const RsqrtPattern *p) : _p(p) {}
105
106public:
107 void apply(void);
108
109private:
110 luci::CircleRsqrt *create_rsqrt(loco::Graph *graph);
111
112private:
113 const RsqrtPattern *_p = nullptr;
114};
115
116luci::CircleRsqrt *FuseRsqrt::create_rsqrt(loco::Graph *graph)
117{
118 assert(graph);
119
120 auto rsqrt = graph->nodes()->create<luci::CircleRsqrt>();
121 rsqrt->x(_p->_ifm);
122 rsqrt->name(_p->_div->name() + "_rsqrt");
123
124 luci::copy_quantparam(_p->_div, rsqrt);
125
126 return rsqrt;
127}
128
129void FuseRsqrt::apply()
130{
131 auto graph = _p->_div->graph();
132
133 auto rsqrt = create_rsqrt(graph);
134
135 // set origin
136 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
137 luci::get_origin(_p->_sqrt), luci::get_origin(_p->_div), luci::get_origin(_p->_div_const)};
138
139 luci::add_origin(rsqrt, luci::composite_origin(origin_vec));
140
141 replace(_p->_div).with(rsqrt);
142}
143
144} // namespace
145
146namespace
147{
148
149bool fuse_rsqrt(luci::CircleDiv *div)
150{
151 assert(div);
152
153 RsqrtPattern pattern(div);
154 if (pattern.matched())
155 {
156 FuseRsqrt fuse(&pattern);
157 fuse.apply();
158 return true;
159 }
160
161 return false;
162}
163
164} // namespace
165
166namespace luci
167{
168
170{
171 bool changed = false;
172
173 for (auto node : loco::active_nodes(loco::output_nodes(g)))
174 {
175 auto div = dynamic_cast<luci::CircleDiv *>(node);
176 if (not div)
177 continue;
178
179 if (fuse_rsqrt(div))
180 changed = true;
181 }
182
183 return changed;
184}
185
186} // namespace luci
A neural network graph.
Definition Graph.h:161
void with(Node *into) const
Definition Node.cpp:66
Class to build tensor data.
Definition CircleConst.h:35
DIV in Circle.
Definition CircleDiv.h:37
RSQRT in Circle.
Definition CircleRsqrt.h:32
loco::Node * x(void) const
Definition CircleRsqrt.h:34
SQRT in Circle.
Definition CircleSqrt.h:32
#define CHECK_OR_FALSE(condition)
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst)
copy CircleQuantParam of src to dst
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
bool run(loco::Graph *g) final
Run the pass.