ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseRsqrtPass.cpp File Reference
#include "luci/Pass/FuseRsqrtPass.h"
#include "helpers/NodeFiller.h"
#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <cmath>
#include <cassert>

Go to the source code of this file.

Namespaces

namespace  luci
 

Macros

#define CHECK_OR_FALSE(condition)
 

Macro Definition Documentation

◆ CHECK_OR_FALSE

#define CHECK_OR_FALSE (   condition)
Value:
if (not(condition)) \
return false;

Definition at line 60 of file FuseRsqrtPass.cpp.

63 :
// Checks whether the candidate DIV (_div) forms the pattern handled by this
// pass and, while doing so, fills in _div_const, _sqrt and _ifm.
// Returns true only if every check passes; CHECK_OR_FALSE returns false from
// this function as soon as its condition does not hold.
64 bool matched()
65 {
66 // Check pattern
   // The arguments of _div must be accepted by the NodeFiller for the
   // (CircleConst, CircleSqrt) pair. NOTE(review): presumably with_args_of()
   // matches the operands in the given order only -- confirm in
   // helpers/NodeFiller.h.
67 CHECK_OR_FALSE(luci::fill(&_div_const, &_sqrt).with_args_of(_div));
   // The input of the SQRT is remembered; it later becomes the input of the
   // CircleRsqrt node created by FuseRsqrt.
68 _ifm = loco::must_cast<luci::CircleNode *>(_sqrt->x());
69
   // A DIV with any fused activation other than NONE is rejected.
70 CHECK_OR_FALSE(_div->fusedActivationFunction() == luci::FusedActFunc::NONE);
71
72 // Check div_const = 1
73 switch (_div->dtype())
74 {
75 case loco::DataType::S16:
   // The constant must carry quantization parameters with exactly one scale
   // and exactly one zero point, the zero point must be 0, and the constant
   // must hold exactly one element.
76 CHECK_OR_FALSE(_div_const->quantparam() != nullptr);
77 CHECK_OR_FALSE(_div_const->quantparam()->scale.size() == 1);
78 CHECK_OR_FALSE(_div_const->quantparam()->zerop.size() == 1);
79 CHECK_OR_FALSE(_div_const->quantparam()->zerop.at(0) == 0);
80 CHECK_OR_FALSE(_div_const->size<loco::DataType::S16>() == 1);
   // The stored value times the scale must compare as 1.0 through same().
   // NOTE(review): same() is defined outside this listing; its tolerance is
   // not visible here.
81 CHECK_OR_FALSE(same(1.0, _div_const->at<loco::DataType::S16>(0) *
82 _div_const->quantparam()->scale.at(0)));
83 break;
84 // TODO Support more dtypes
85 default:
   // Every dtype other than S16 is reported as "not matched".
86 return false;
87 }
88
89 return true;
90 }
91#undef CHECK_OR_FALSE
92
93public:
// Input of the SQRT node; assigned in matched() from _sqrt->x().
94 luci::CircleNode *_ifm = nullptr;
// SQRT node filled in by matched() from the arguments of _div.
95 luci::CircleSqrt *_sqrt = nullptr;
// DIV node being inspected. NOTE(review): presumably set by the constructor,
// which is outside this listing -- confirm.
96 luci::CircleDiv *_div = nullptr;
// Constant filled in by matched() from the arguments of _div and checked
// there to represent the value 1.
97 luci::CircleConst *_div_const = nullptr;
98};
99
// Applies the rewrite for an RsqrtPattern: apply() creates a CircleRsqrt node
// and substitutes it for the pattern's DIV node.
100class FuseRsqrt final
101{
102public:
    // Stores the given pattern pointer as-is (no copy is made). The caller in
    // this file passes the address of a local RsqrtPattern, so the pattern has
    // to stay alive while this object is used.
103 FuseRsqrt(const RsqrtPattern *p) : _p(p) {}
104
105public:
    // Performs the replacement described above.
106 void apply(void);
107
108private:
    // Creates the CircleRsqrt node inside the given graph and returns it.
109 luci::CircleRsqrt *create_rsqrt(loco::Graph *graph);
110
111private:
    // Pattern this object operates on; only read, never modified through here.
112 const RsqrtPattern *_p = nullptr;
113};
114
// Creates a new CircleRsqrt node in `graph` and returns it.
// The node takes the pattern's _ifm as its input, is named after the DIV node
// with the suffix "_rsqrt", and receives the DIV node's quantization
// parameters via luci::copy_quantparam().
// `graph` must not be null; this is only checked with assert().
115luci::CircleRsqrt *FuseRsqrt::create_rsqrt(loco::Graph *graph)
116{
117 assert(graph);
118
119 auto rsqrt = graph->nodes()->create<luci::CircleRsqrt>();
120 rsqrt->x(_p->_ifm);
121 rsqrt->name(_p->_div->name() + "_rsqrt");
122
    // The new node stands in for the DIV, so the DIV's quantization
    // parameters are the ones copied onto it.
123 luci::copy_quantparam(_p->_div, rsqrt);
124
125 return rsqrt;
126}
127
// Performs the fusion for the stored pattern:
//  1. creates the CircleRsqrt node in the graph that owns the DIV node,
//  2. attaches an origin composed from the origins of the SQRT, the DIV and
//     the constant,
//  3. substitutes the new node for the DIV via replace(...).with(...).
// The original SQRT, DIV and constant nodes are not removed by this function.
128void FuseRsqrt::apply()
129{
130 auto graph = _p->_div->graph();
131
132 auto rsqrt = create_rsqrt(graph);
133
134 // set origin
135 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
136 luci::get_origin(_p->_sqrt), luci::get_origin(_p->_div), luci::get_origin(_p->_div_const)};
137
138 luci::add_origin(rsqrt, luci::composite_origin(origin_vec));
139
140 replace(_p->_div).with(rsqrt);
141}
142
143} // namespace
144
145namespace
146{
147
// Tries to fuse the pattern rooted at `div`.
// Builds an RsqrtPattern for `div`; if it matches, applies FuseRsqrt and
// returns true. Returns false, leaving the graph untouched, when the pattern
// does not match. `div` must not be null (checked with assert() only).
148bool fuse_rsqrt(luci::CircleDiv *div)
149{
150 assert(div);
151
152 RsqrtPattern pattern(div);
153 if (pattern.matched())
154 {
    // `pattern` is a local that outlives `fuse`, so handing out its address
    // here is safe.
155 FuseRsqrt fuse(&pattern);
156 fuse.apply();
157 return true;
158 }
159
160 return false;
161}
162
163} // namespace
164
165namespace luci
166{
167
// NOTE(review): the listing omits source line 168, which should carry this
// function's signature -- presumably `bool FuseRsqrtPass::run(loco::Graph *g)`
// (the page documents `bool run(loco::Graph *g) final`); confirm against the
// original file.
// Visits every active node reachable from the graph outputs, attempts the
// fusion on each CircleDiv, and returns true if at least one fusion happened.
169{
170 bool changed = false;
171
172 for (auto node : loco::active_nodes(loco::output_nodes(g)))
173 {
    // Only DIV nodes are candidates; every other node type is skipped.
174 auto div = dynamic_cast<luci::CircleDiv *>(node);
175 if (not div)
176 continue;
177
178 if (fuse_rsqrt(div))
179 changed = true;
180 }
181
182 return changed;
183}
184
185} // namespace luci
A neural network graph.
Definition Graph.h:161
void with(Node *into) const
Definition Node.cpp:66
Class to build tensor data.
Definition CircleConst.h:35
DIV in Circle.
Definition CircleDiv.h:37
RSQRT in Circle.
Definition CircleRsqrt.h:32
loco::Node * x(void) const
Definition CircleRsqrt.h:34
SQRT in Circle.
Definition CircleSqrt.h:32
#define CHECK_OR_FALSE(condition)
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst)
Copy the CircleQuantParam of src to dst.
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
bool run(loco::Graph *g) final
Run the pass.