63 :
64 bool matched() // Returns true when the Sqrt/Div nodes around _div satisfy this pattern's checks; binds _ifm as a side effect.
65 {
66
68 _ifm = loco::must_cast<luci::CircleNode *>(_sqrt->x()); // Operand of the matched Sqrt node (presumably the input feature map, per the name).
69
71
72
73 switch (_div->dtype()) // Remaining checks depend on the Div's element type.
74 {
75 case loco::DataType::S16:
82 _div_const->quantparam()->scale.at(0))); // S16 path inspects the first quantization scale of the Div constant; the statement begins on lines not visible here.
83 break;
84
85 default:
86 return false; // Any dtype without a case above is rejected, so no fusion happens for it.
87 }
88
89 return true;
90 }
91#undef CHECK_OR_FALSE
92
93public:
98};
99
100class FuseRsqrt final // Performs the Rsqrt fusion for a pattern that has already passed RsqrtPattern::matched() (see fuse_rsqrt).
101{
102public:
103 FuseRsqrt(const RsqrtPattern *p) : _p(p) {} // Stores the pattern pointer only; the pattern must stay alive while apply() runs.
104
105public:
107
108private:
110
111private:
112 const RsqrtPattern *_p = nullptr; // Read-only access to the matched nodes (_sqrt, _div, _div_const) used by apply().
113};
114
116{
117 assert(graph);
118
121 rsqrt->name(_p->_div->name() + "_rsqrt");
122
124
125 return rsqrt;
126}
127
128void FuseRsqrt::apply() // Creates the fused Rsqrt node and gathers the origins of the nodes it stands in for.
129{
130 auto graph = _p->_div->graph(); // The new node is created in the same graph that owns the matched Div.
131
132 auto rsqrt = create_rsqrt(graph);
133
134
135 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{ // Origins of the Sqrt, Div and Div constant; presumably attached to rsqrt on a line not visible here.
136 luci::get_origin(_p->_sqrt), luci::get_origin(_p->_div), luci::get_origin(_p->_div_const)};
137
139
141}
142
143}
144
145namespace
146{
147
149{
150 assert(div);
151
152 RsqrtPattern pattern(div);
153 if (pattern.matched())
154 {
155 FuseRsqrt fuse(&pattern);
156 fuse.apply();
157 return true;
158 }
159
160 return false;
161}
162
163}
164
166{
167
169{
170 bool changed = false;
171
173 {
175 if (not div)
176 continue;
177
178 if (fuse_rsqrt(div))
179 changed = true;
180 }
181
182 return changed;
183}
184
185}
void with(Node *into) const
Class to build tensor data.
loco::Node * x(void) const
#define CHECK_OR_FALSE(condition)
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Subst< SubstQualifier::Default > replace(Node *node)
void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst)
Copy the CircleQuantParam of src to dst.
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
bool run(loco::Graph *g) final
Run the pass.