ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseRsqrtPass.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "FuseRsqrtPass.h"
18
19#include "Check.h"
20
21#include "Dialect/IR/TFLNodes.h"
22
23namespace
24{
25
34locoex::TFLDiv *as_candidate(loco::Node *node)
35{
36 auto div = dynamic_cast<locoex::TFLDiv *>(node);
37 if (not div)
38 return nullptr;
39
40 // Cannot fuse Div with activation function
41 if (div->fusedActivationFunction() != locoex::FusedActFunc::NONE)
42 return nullptr;
43
44 auto const_one = dynamic_cast<locoex::TFLConst *>(div->x());
45 if (not const_one)
46 return nullptr;
47
48 const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
49 // TODO Support other dtype
50 EXO_ASSERT(const_one->dtype() == FLOAT32, "Only support FLOAT32 now");
51 for (uint32_t i = 0; i < const_one->size<FLOAT32>(); ++i)
52 if (const_one->at<FLOAT32>(i) != 1.0f)
53 return nullptr;
54
55 auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
56 if (not sqrt)
57 return nullptr;
58
59 return div;
60}
61
62void fuse_rsqrt(locoex::TFLDiv *div)
63{
64 auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
65 EXO_ASSERT(sqrt, "sqrt should be valid at this point");
66
67 // TFLRsqrt to replace
68 auto rsqrt = div->graph()->nodes()->create<locoex::TFLRsqrt>();
69 rsqrt->x(sqrt->x());
70
71 // replace
72 loco::replace(div).with(rsqrt);
73}
74
75} // namespace
76
77namespace exo
78{
79
81{
 // Body of the pass entry point: fuses every "1 / sqrt(x)" Div found in the graph.
 // `changed` is returned to report whether at least one fusion was performed.
82 bool changed = false;
 // Walk the nodes required to compute the graph outputs. active_nodes() returns
 // its std::set by value, so the loop iterates over a snapshot of the node set.
83 for (auto node : loco::active_nodes(loco::output_nodes(g)))
84 {
 // as_candidate() yields nullptr for any node that is not a fusable Div
85 if (auto div = as_candidate(node))
86 {
 // Substitute the Div with a newly created TFLRsqrt
87 fuse_rsqrt(div);
88 changed = true;
89 }
90 }
91
92 return changed;
93}
94
95} // namespace exo
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
void with(Node *into) const
Definition Node.cpp:66
Class to build tensor data.
Definition TFLNodes.h:198
DIV in TensorFlow Lite.
Definition TFLNodes.h:280
loco::Node * x(void) const
Definition TFLNodes.h:453
#define EXO_ASSERT(condition, msg)
Definition Check.h:28
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
DataType
"scalar" value type
Definition DataType.h:27
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
bool run(loco::Graph *g) final
Run the pass.