ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseInstanceNormPass.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18
19#include "Dialect/IR/TFLNodes.h"
21
23
24#include <cassert>
25#include <set>
26
27// Helper to find commutative node's arguments
28namespace
29{
30
/**
 * @brief Binds the two inputs of a commutative node to typed output slots.
 *
 * Given a node exposing x() and y(), tries to dynamic_cast them to the two
 * requested pointer types in either order (the node is commutative, so both
 * orders are acceptable).
 */
template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
{
public:
  NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
  {
    // DO NOTHING
  }

  /**
   * @brief Tries to fill (*_arg_1, *_arg_2) from node->x()/node->y()
   * @return true when both slots were matched; false otherwise (slots untouched)
   */
  template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);

private:
  ARG_TYPE_1 **_arg_1;
  ARG_TYPE_2 **_arg_2;
};

/// @brief Factory helper so template arguments are deduced at the call site
template <class ARG_TYPE_1, class ARG_TYPE_2>
inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
{
  return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
}

template <class ARG_TYPE_1, class ARG_TYPE_2>
template <class COMM_NODE>
bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
{
  // First attempt: x is ARG_TYPE_1 and y is ARG_TYPE_2
  if (auto lhs = dynamic_cast<ARG_TYPE_1 *>(node->x()))
  {
    if (auto rhs = dynamic_cast<ARG_TYPE_2 *>(node->y()))
    {
      *_arg_1 = lhs;
      *_arg_2 = rhs;
      return true;
    }
  }

  // Second attempt: swapped order — x is ARG_TYPE_2 and y is ARG_TYPE_1
  if (auto swapped_x = dynamic_cast<ARG_TYPE_2 *>(node->x()))
  {
    if (auto swapped_y = dynamic_cast<ARG_TYPE_1 *>(node->y()))
    {
      *_arg_1 = swapped_y;
      *_arg_2 = swapped_x;
      return true;
    }
  }

  return false;
}
114
115} // namespace
116
117// Helper to check detail
118namespace
119{
120
122bool is_1D_with_dummy_dim(locoex::TFLConst *node, uint32_t depth)
123{
124 auto rank = node->rank();
125 uint32_t axis;
126 for (axis = 0; axis < rank - 1; ++axis)
127 {
128 if (node->dim(axis).value() != 1)
129 return false;
130 }
131 return node->dim(axis).value() == depth;
132}
133
134bool is_instance_mean(locoex::TFLMean *mean)
135{
136 //
137 // CHECK 1) input is rank 4
138 //
139 auto input = mean->input();
140 if (not loco::shape_known(input))
141 return false;
142 auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
143 if (input_shape.rank() != 4)
144 return false;
145
146 //
147 // CHECK 2) 'reduction indices' is TFLConst of value [1,2], that is HW of NHWC
148 //
149 // TODO Support equivalent case, like [-3,-2]
150 // TODO Support non-Const case?
151 // TODO What if input is NCHW format in Circle?
152 auto red_indices = dynamic_cast<locoex::TFLConst *>(mean->reduction_indices());
153 if (not red_indices)
154 return false;
155 if (red_indices->rank() != 1)
156 return false;
157 std::set<int32_t> red_indices_set;
158 {
159 // TODO Currently only support S32, support other types
160 assert(red_indices->dtype() == loco::DataType::S32);
161 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
162 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
163 }
164 if (red_indices_set.size() != 2)
165 return false;
166 if (red_indices_set.find(1) == red_indices_set.end())
167 return false;
168 if (red_indices_set.find(2) == red_indices_set.end())
169 return false;
170
171 //
172 // CHECK 3) keep_dims == true (?)
173 //
174 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
175 // TODO Check this fact, and if true, return true regardless of keep_dims
176 return mean->keep_dims();
177}
178
179} // namespace
180
181// Helper to fuse Instance Norm
182namespace
183{
184
232class InstanceNormPattern final
233{
234public:
235 InstanceNormPattern(locoex::TFLAdd *candidate)
236 {
237 assert(candidate);
238 add_as_terminal = candidate;
239 }
240
241public:
242 bool matched();
243 bool matched() const { return _matched; }
244
245public:
246 // Context
247 loco::Node *ifm = nullptr;
248 locoex::TFLMean *mean_of_ifm = nullptr;
249 locoex::TFLSquaredDifference *sqdiff = nullptr;
250 locoex::TFLMean *mean_as_variance = nullptr;
251 locoex::TFLConst *const_as_epsilon = nullptr;
252 locoex::TFLAdd *add_as_variance = nullptr;
253 locoex::TFLRsqrt *rsqrt = nullptr;
254 locoex::TFLConst *const_as_gamma = nullptr;
255 locoex::TFLMul *mul_gamma = nullptr;
256 locoex::TFLMul *mul_as_scaled_ifm = nullptr;
257 locoex::TFLMul *mul_as_scaled_mean = nullptr;
258 locoex::TFLConst *const_as_beta = nullptr;
259 locoex::TFLSub *sub = nullptr;
260 locoex::TFLAdd *add_as_terminal = nullptr;
261
262private:
263 bool _matched = false;
264};
265
266bool InstanceNormPattern::matched()
267{
268 if (_matched)
269 return true;
270
271#define CHECK_OR_FALSE(condition) \
272 if (not(condition)) \
273 return false;
274
275 // Check order is DFS
276
277 CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
278 CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
279
281 auto ifm_shape = loco::shape_get(ifm);
282 CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor);
283 auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>();
284 CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4);
285 uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value();
286
287 CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
288 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
289
290 add_as_variance = dynamic_cast<locoex::TFLAdd *>(rsqrt->x());
291 CHECK_OR_FALSE(add_as_variance);
292
294 fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
295
296 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
297 // TODO Support regarding broadcast
298 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
299
300 CHECK_OR_FALSE(is_instance_mean(mean_as_variance));
301 sqdiff = dynamic_cast<locoex::TFLSquaredDifference *>(mean_as_variance->input());
302 CHECK_OR_FALSE(sqdiff);
303
304 loco::Node *ifm_should_be = nullptr;
305 CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
306 CHECK_OR_FALSE(ifm == ifm_should_be);
307 CHECK_OR_FALSE(is_instance_mean(mean_of_ifm));
308 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
309
310 const_as_beta = dynamic_cast<locoex::TFLConst *>(sub->x());
311 CHECK_OR_FALSE(const_as_beta);
312 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
313
314 mul_as_scaled_mean = dynamic_cast<locoex::TFLMul *>(sub->y());
315 CHECK_OR_FALSE(mul_as_scaled_mean);
316
317 locoex::TFLMul *mul_gamma_should_be = nullptr;
318 locoex::TFLMean *mean_of_ifm_should_be = nullptr;
319 CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
320 .with_commutative_args_of(mul_as_scaled_mean));
321 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
322 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
323#undef CHECK_OR_FALSE
324 _matched = true;
325 return true;
326}
327
343void fuse_instance_norm(const InstanceNormPattern &p)
344{
345 assert(p.matched());
346
347 auto graph = p.add_as_terminal->graph();
348
349 // Make reshape for gamma & beta
350 auto reshape_gamma = graph->nodes()->create<locoex::TFLReshape>();
351 auto reshape_beta = graph->nodes()->create<locoex::TFLReshape>();
352 {
353 auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>();
354 uint32_t ifm_channel_depth = ifm_shape.dim(3).value();
355
356 int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)};
357
358 reshape_gamma->tensor(p.const_as_gamma);
359 reshape_beta->tensor(p.const_as_beta);
360
361 locoex::set_new_shape(reshape_gamma, new_shape, 1);
362 locoex::set_new_shape(reshape_beta, new_shape, 1);
363 }
364
365 // Make Instance Norm to replace
366 auto instance_norm = graph->nodes()->create<locoex::CircleInstanceNorm>();
367 instance_norm->input(p.ifm);
368 instance_norm->gamma(reshape_gamma);
369 instance_norm->beta(reshape_beta);
370 float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
371 instance_norm->epsilon(epsilon);
372 instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
373
374 replace(p.add_as_terminal).with(instance_norm);
375}
376
377} // namespace
378
379namespace exo
380{
381
383{
384 bool changed = false;
385 for (auto node : loco::active_nodes(loco::output_nodes(g)))
386 {
387 auto add = dynamic_cast<locoex::TFLAdd *>(node);
388 if (not add)
389 continue;
390
391 InstanceNormPattern pattern(add);
392 if (not pattern.matched())
393 continue;
394
395 fuse_instance_norm(pattern);
396 changed = true;
397 }
398
399 return changed;
400}
401
402} // namespace exo
uint32_t value(void) const
Return the value.
Definition Dimension.h:51
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
ShapeType as(void) const
void with(Node *into) const
Definition Node.cpp:66
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
INSTANCE_NORM in circle.
Definition CircleNodes.h:58
loco::Node * input(void) const
Definition CircleNodes.h:61
ADD in TensorFlow Lite.
Definition TFLNodes.h:116
Class to build tensor data.
Definition TFLNodes.h:198
loco::Node * input(void) const
Definition TFLNodes.h:356
bool keep_dims(void) const
Definition TFLNodes.h:363
loco::Node * reduction_indices(void) const
Definition TFLNodes.h:359
MUL in TensorFlow Lite.
Definition TFLNodes.h:375
loco::Node * tensor(void) const
Definition TFLNodes.h:410
SUB in TensorFlow Lite.
Definition TFLNodes.h:488
#define CHECK_OR_FALSE(condition)
bool is_1D_with_dummy_dim(locoex::TFLConst *node, uint32_t depth)
bool shape_known(const Node *node)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
NodeShape shape_get(const Node *node)
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size)
Set both TFLReshape's 2nd input as TFLConst, and newShape attribute with same value.
Definition TFLNodes.cpp:67
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
bool run(loco::Graph *g) final
Run the pass.