ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseInstanceNormPass.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "helpers/NodeFiller.h"
20
21#include <luci/IR/CircleNodes.h>
22
25
26#include <cassert>
27#include <set>
28#include <optional>
29
30// Helper to check detail
31
// Return false from the enclosing function as soon as 'condition' does not hold.
//
// The body is wrapped in do { } while (false) so the macro always expands to exactly
// one statement. A bare 'if' would not compose with an unbraced if/else at the call
// site. Every invocation must be terminated with ';'.
#define CHECK_OR_FALSE(condition) \
  do                              \
  {                               \
    if (not(condition))           \
      return false;               \
  } while (false)
35
38bool is_unsqueezed_1D(luci::CircleConst *node, uint32_t depth)
39{
40 const auto rank = node->rank();
41 std::optional<uint32_t> depth_axis;
42 for (uint32_t axis = 0; axis < rank; ++axis)
43 {
44 if (node->dim(axis).value() != 1)
45 {
46 // only one axis can be other than 1
47 if (depth_axis.has_value())
48 {
49 return false;
50 }
51 depth_axis = axis;
52 }
53 }
54 if (!depth_axis.has_value())
55 {
56 return false;
57 }
58 return node->dim(depth_axis.value()).value() == depth;
59}
60
65 luci::CircleReshape *terminal_reshape)
66{
67 auto const begin_reshape_ifm = dynamic_cast<luci::CircleNode *>(begin_reshape->tensor());
68 CHECK_OR_FALSE(begin_reshape_ifm);
69
70 // check last axis
71 CHECK_OR_FALSE((begin_reshape_ifm->rank() + 1) == begin_reshape->rank());
72
73 // check unchanged part of begin_shape
74 for (uint32_t axis = 0; axis < begin_reshape_ifm->rank(); ++axis)
75 {
76 // skip dynamic cases
77 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).known() && begin_reshape->dim(axis).known());
78 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).value() == begin_reshape->dim(axis).value());
79 }
80 // check last axis
81 CHECK_OR_FALSE(begin_reshape->dim(begin_reshape->rank() - 1) == 1);
82
83 auto const terminal_reshape_ifm = dynamic_cast<luci::CircleNode *>(terminal_reshape->tensor());
84 CHECK_OR_FALSE(terminal_reshape_ifm);
85
86 CHECK_OR_FALSE(terminal_reshape_ifm->rank() == terminal_reshape->rank() + 1);
87
88 // check last axis
89 CHECK_OR_FALSE(terminal_reshape_ifm->dim(begin_reshape->rank() - 1) == 1);
90
91 // check unchanged part of terminal_reshape
92 for (uint32_t axis = 0; axis < terminal_reshape->rank(); ++axis)
93 {
94 // skip dynamic cases
95 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).known() && terminal_reshape->dim(axis).known());
96 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).value() == terminal_reshape->dim(axis).value());
97 }
98
99 return true;
100}
101
103{
104 //
105 // CHECK 1) input is rank 4
106 //
107 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
108 if (input->shape_status() != luci::ShapeStatus::VALID)
109 return false;
110 if (input->rank() != 4)
111 return false;
112
113 //
114 // CHECK 2) 'reduction indices' is CircleConst of value [1,2], that is HW of NHWC
115 //
116 // TODO Support equivalent case, like [-3,-2]
117 // TODO Support non-Const case?
118 // TODO What if input is NCHW format in Circle?
119 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
120 if (not red_indices)
121 return false;
122 if (red_indices->rank() != 1)
123 return false;
124 std::set<int32_t> red_indices_set;
125 {
126 // TODO Currently only support S32, support other types
127 assert(red_indices->dtype() == loco::DataType::S32);
128 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
129 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
130 }
131 if (red_indices_set.size() != 2)
132 return false;
133 if (red_indices_set.find(1) == red_indices_set.end())
134 return false;
135 if (red_indices_set.find(2) == red_indices_set.end())
136 return false;
137
138 //
139 // CHECK 3) keep_dims == true (?)
140 //
141 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
142 // TODO Check this fact, and if true, return true regardless of keep_dims
143 return mean->keep_dims();
144}
145
147{
148 //
149 // CHECK 1) input is rank 3
150 //
151 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
152 if (input->shape_status() != luci::ShapeStatus::VALID)
153 return false;
154 if (input->rank() != 3)
155 return false;
156
157 //
158 // CHECK 2) 'reduction indices' is CircleConst of value [2], that is last dim of rank 3
159 //
160 // TODO Support non-Const case?
161 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
162 if (not red_indices)
163 return false;
164 if (red_indices->rank() != 1)
165 return false;
166 std::set<int32_t> red_indices_set;
167 {
168 // TODO Currently only support S32, support other types
169 assert(red_indices->dtype() == loco::DataType::S32);
170 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
171 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
172 }
173 if (red_indices_set.size() != 1)
174 return false;
175 if (red_indices_set.find(2) == red_indices_set.end())
176 return false;
177
178 //
179 // CHECK 3) keep_dims == true (?)
180 //
181 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
182 // TODO Check this fact, and if true, return true regardless of keep_dims
183 return mean->keep_dims();
184}
185
187bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
188{
189 if (node->rank() != 1)
190 return false;
191
192 if (node->dim(0).value() != channel_size)
193 return false;
194
195 if (node->dtype() != loco::DataType::FLOAT32)
196 return false;
197
198 if (node->size<loco::DataType::FLOAT32>() != channel_size)
199 return false;
200
201 return true;
202}
203
204// Helper to fuse Instance Norm
205namespace
206{
207
483class InstanceNormPattern final
484{
485public:
486 enum PatternVersion
487 {
488 Version_Unknown,
489 Version_1,
490 Version_2,
491 Version_3,
492 Version_4,
493 Version_5,
494 Version_6, // For only 3D I/O
495 Version_7,
496 };
497
498 InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv)
499 {
500 assert(candidate);
501 add_as_terminal = candidate;
502 _pv = pv;
503 }
504
505 InstanceNormPattern(luci::CircleDiv *candidate, PatternVersion pv)
506 {
507 assert(candidate);
508 div = candidate;
509 _pv = pv;
510 }
511
512 InstanceNormPattern(luci::CircleReshape *candidate, PatternVersion pv)
513 {
514 assert(candidate);
515 reshape_as_terminal = candidate;
516 _pv = pv;
517 }
518
519private:
520 bool condition_common_1_5(uint32_t ifm_channel_depth);
521 bool condition_common_3_4();
522
523private:
524 template <enum PatternVersion> bool match();
525
526public:
527 bool matched();
528 bool matched() const { return _matched; }
529
530 PatternVersion version() const { return _pv; }
531
532public:
533 // Context
534 loco::Node *ifm = nullptr;
535 luci::CircleReshape *reshape_of_ifm = nullptr;
536 luci::CircleMean *mean_of_ifm = nullptr;
537 luci::CircleMean *mean_of_ifm_2 = nullptr;
538 luci::CircleMean *mean_of_reshape = nullptr;
539 luci::CircleSquaredDifference *sqdiff = nullptr;
540 luci::CircleSquare *square = nullptr;
541 luci::CircleMean *mean_as_variance = nullptr;
542 luci::CircleConst *const_as_epsilon = nullptr;
543 luci::CircleAdd *add_as_variance = nullptr;
544 luci::CircleAdd *add_neg_mul = nullptr;
545 luci::CircleRsqrt *rsqrt = nullptr;
546 luci::CircleConst *const_as_gamma = nullptr;
547 luci::CircleMul *mul_gamma = nullptr;
548 luci::CircleMul *mul_as_scaled_ifm = nullptr;
549 luci::CircleMul *mul_as_scaled_mean = nullptr;
550 luci::CircleMul *mul_as_scaled_reshape = nullptr;
551 luci::CircleConst *const_as_beta = nullptr;
552 luci::CircleSub *sub = nullptr;
553 luci::CircleSub *sub_2 = nullptr;
554 luci::CircleAdd *add_as_terminal = nullptr;
555 luci::CirclePow *pow = nullptr;
556 luci::CircleSqrt *sqrt = nullptr;
557 luci::CircleDiv *div = nullptr;
558 luci::CircleConst *reshape_terminal_target_shape = nullptr;
559 luci::CircleReshape *reshape_as_terminal = nullptr;
560 luci::CircleNeg *neg_mean = nullptr;
561
562private:
563 bool _matched = false;
564 PatternVersion _pv;
565};
566
567bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
568{
569 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
570 CHECK_OR_FALSE(add_as_variance);
571
573 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
574
575 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
576 // TODO Support regarding broadcast
577 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
578
579 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
580
581 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
582 CHECK_OR_FALSE(sqdiff);
583
584 loco::Node *ifm_should_be = nullptr;
585 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
586 CHECK_OR_FALSE(ifm == ifm_should_be);
588 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
589
590 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
591 CHECK_OR_FALSE(const_as_beta);
592 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_beta, ifm_channel_depth));
593
594 return true;
595}
596
597bool InstanceNormPattern::condition_common_3_4()
598{
599 // check left sub
600 ifm = sub->x();
601 CHECK_OR_FALSE(ifm);
602
604 CHECK_OR_FALSE(ifm_node->rank() == 4);
605 CHECK_OR_FALSE(ifm_node->dim(3).known());
606
607 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
608 CHECK_OR_FALSE(mean_of_ifm);
609 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
610
611 // continue search from add_as_variance
612 CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
613 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
614 // TODO Support regarding broadcast
615 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
616
617 mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
618 CHECK_OR_FALSE(mean_as_variance);
619
620 square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
621 CHECK_OR_FALSE(square);
622
623 sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
624 CHECK_OR_FALSE(sub_2);
625 CHECK_OR_FALSE(ifm == sub_2->x());
626
627 mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
628 CHECK_OR_FALSE(mean_of_ifm_2);
629 CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
630
631 loco::Node *ifm_should_be = nullptr;
632 luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
634 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
635 CHECK_OR_FALSE(ifm == ifm_should_be);
636 CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
637
638 return true;
639}
640
641template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
642{
643 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
644 CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
645
646 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
647 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
648 CHECK_OR_FALSE(ifm_circle->rank() == 4);
649 CHECK_OR_FALSE(ifm_circle->dim(3).known());
650 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
651
652 CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
653
654 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_gamma, ifm_channel_depth));
655
656 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
657
658 luci::CircleMul *mul_gamma_should_be = nullptr;
659 luci::CircleMean *mean_of_ifm_should_be = nullptr;
660
661 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
662 CHECK_OR_FALSE(mul_as_scaled_mean);
663 CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
664 .with_commutative_args_of(mul_as_scaled_mean));
665 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
666 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
667
668 _matched = true;
669 return true;
670}
671
672template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
673{
674 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
675 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
676
677 sub = dynamic_cast<luci::CircleSub *>(div->x());
678 CHECK_OR_FALSE(sub);
679
680 ifm = sub->x();
681 CHECK_OR_FALSE(ifm);
682
684 CHECK_OR_FALSE(ifm_node->rank() == 4);
685 CHECK_OR_FALSE(ifm_node->dim(3).known());
686 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
687
688 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
689 CHECK_OR_FALSE(mean_of_ifm);
690
691 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
692
693 pow = dynamic_cast<luci::CirclePow *>(div->y());
694 CHECK_OR_FALSE(pow);
695
696 add_as_variance = dynamic_cast<luci::CircleAdd *>(pow->x());
697 CHECK_OR_FALSE(add_as_variance);
698
699 luci::CircleConst *zero_point_five = dynamic_cast<luci::CircleConst *>(pow->y());
700 CHECK_OR_FALSE(zero_point_five);
701 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
702 // TODO Support regarding broadcast
703 CHECK_OR_FALSE(zero_point_five->size<loco::DataType::FLOAT32>() == 1);
704 CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
705
707 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
708 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
709 // TODO Support regarding broadcast
710 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
711
712 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
713
714 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
715 CHECK_OR_FALSE(sqdiff);
716
717 loco::Node *ifm_should_be = nullptr;
718 luci::CircleMean *mean_of_ifm_should_be = nullptr;
720 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
721 CHECK_OR_FALSE(ifm == ifm_should_be);
722 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
723
724 // Check for channel size
725 CHECK_OR_FALSE(is_1D_float32_const(const_as_gamma, ifm_channel_depth));
726 CHECK_OR_FALSE(is_1D_float32_const(const_as_beta, ifm_channel_depth));
727
728 _matched = true;
729 return true;
730}
731
732template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
733{
734 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
735 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
736 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
737
738 CHECK_OR_FALSE(condition_common_3_4());
739
740 _matched = true;
741 return true;
742}
743
744luci::CircleConst *make_const_one(loco::Graph *graph, float value)
745{
746 auto const_one = graph->nodes()->create<luci::CircleConst>();
747 const_one->dtype(loco::DataType::FLOAT32);
748 const_one->rank(1);
749 const_one->size<loco::DataType::FLOAT32>(1);
750 const_one->at<loco::DataType::FLOAT32>(0) = value;
751 return const_one;
752}
753
754template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
755{
756 CHECK_OR_FALSE(div);
757 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
758
759 CHECK_OR_FALSE(condition_common_3_4());
760
761 assert(const_as_gamma == nullptr);
762 assert(const_as_beta == nullptr);
763 assert(mul_gamma == nullptr);
764 assert(add_as_terminal == nullptr);
765
766 // create 1.0 gamma and 0.0 beta
767 auto graph = div->graph();
768 const_as_gamma = make_const_one(graph, 1.0f);
769 const_as_beta = make_const_one(graph, 0.0f);
770 const_as_gamma->name(div->name() + "/gamma");
771 const_as_beta->name(div->name() + "/beta");
772
773 _matched = true;
774 return true;
775}
776
777template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
778{
779 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
780 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
781
782 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
783 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
784 CHECK_OR_FALSE(ifm_circle->rank() == 4);
785 CHECK_OR_FALSE(ifm_circle->dim(3).known());
786 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
787
788 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
789
790 luci::CircleRsqrt *rsqrt_should_be = nullptr;
791 luci::CircleMean *mean_of_ifm_should_be = nullptr;
792
793 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
794 CHECK_OR_FALSE(mul_as_scaled_mean);
795 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
796 .with_commutative_args_of(mul_as_scaled_mean));
797 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
798 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
799
800 // mul_gamma is absent
801 // const_as_gamma assume to be 1.0
802 auto graph = add_as_terminal->graph();
803 const_as_gamma = make_const_one(graph, 1.0f);
804 const_as_gamma->name(add_as_terminal->name() + "/gamma");
805
806 _matched = true;
807 return true;
808}
809
810template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
811{
812 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
813 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
814
815 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
816 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
817 CHECK_OR_FALSE(ifm_circle->rank() == 3);
818 CHECK_OR_FALSE((ifm_circle->dim(1).known()));
819
820 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
821 CHECK_OR_FALSE(add_as_variance);
822
824 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
825
826 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
827 // TODO Support regarding broadcast
828 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
829
830 CHECK_OR_FALSE(is_instance_mean_v2(mean_as_variance));
831
832 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
833 CHECK_OR_FALSE(sqdiff);
834
835 loco::Node *ifm_should_be = nullptr;
836 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
837 CHECK_OR_FALSE(ifm == ifm_should_be);
839 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
840
841 // If const_as_beta has shape of '1 x chennel x (1 or input last dimension)'
842 uint32_t input_channel = ifm_circle->dim(1).value();
843 uint32_t input_last_dim = ifm_circle->dim(2).value();
844 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
845 CHECK_OR_FALSE(const_as_beta);
846 CHECK_OR_FALSE(const_as_beta->rank() == 3);
848 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
849 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
850
851 luci::CircleRsqrt *rsqrt_should_be = nullptr;
852 luci::CircleMean *mean_of_ifm_should_be = nullptr;
853
854 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
855 CHECK_OR_FALSE(mul_as_scaled_mean);
856 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
857 .with_commutative_args_of(mul_as_scaled_mean));
858 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
859 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
860
861 // mul_gamma is absent
862 // const_as_gamma assume to be 1.0
863 auto graph = add_as_terminal->graph();
864 const_as_gamma = make_const_one(graph, 1.0f);
865 const_as_gamma->name(add_as_terminal->name() + "/gamma");
866
867 _matched = true;
868 return true;
869}
870
871template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
872{
873 add_as_terminal = dynamic_cast<luci::CircleAdd *>(reshape_as_terminal->tensor());
874 CHECK_OR_FALSE(add_as_terminal);
875
877 luci::fill(&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of(add_as_terminal));
879 luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
880
881 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(add_neg_mul->x());
882 CHECK_OR_FALSE(mul_as_scaled_mean);
883
884 neg_mean = dynamic_cast<luci::CircleNeg *>(mul_as_scaled_mean->x());
885 CHECK_OR_FALSE(neg_mean);
886
887 luci::CircleMul *mul_gamma_should_be = nullptr;
888 luci::CircleNeg *neg_should_be = nullptr;
889
891 luci::fill(&mul_gamma_should_be, &neg_should_be).with_commutative_args_of(mul_as_scaled_mean));
892
893 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
894 CHECK_OR_FALSE(neg_mean == neg_should_be);
895
896 mean_of_ifm = dynamic_cast<luci::CircleMean *>(neg_mean->x());
897 CHECK_OR_FALSE(mean_of_ifm);
898
899 luci::CircleReshape *reshape_of_ifm_should_be = nullptr;
900 reshape_of_ifm_should_be = dynamic_cast<luci::CircleReshape *>(mean_of_ifm->input());
901 CHECK_OR_FALSE(reshape_of_ifm_should_be == reshape_of_ifm);
902
903 ifm = reshape_of_ifm->tensor();
904 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
905 CHECK_OR_FALSE(ifm_circle);
906
907 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
908 CHECK_OR_FALSE(ifm_circle->rank() == 4);
909 CHECK_OR_FALSE(ifm_circle->dim(3).known());
910 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
911
912 const_as_beta = dynamic_cast<luci::CircleConst *>(add_neg_mul->y());
913 CHECK_OR_FALSE(const_as_beta);
914 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_beta, ifm_channel_depth));
915
916 CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
917 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_gamma, ifm_channel_depth));
918
919 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
920 CHECK_OR_FALSE(add_as_variance);
921
923 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
924 CHECK_OR_FALSE(mean_as_variance);
925
926 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
927 // TODO Support regarding broadcast
928 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
929
930 square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
931 CHECK_OR_FALSE(square);
932
933 sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
934 CHECK_OR_FALSE(sub_2);
935
936 auto mean_of_ifm_should_be = dynamic_cast<luci::CircleMean *>(sub_2->y());
937 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
938
939 auto reshape_of_ifm_should_be_2 = dynamic_cast<luci::CircleReshape *>(sub_2->x());
940 CHECK_OR_FALSE(reshape_of_ifm_should_be_2 == reshape_of_ifm);
941
942 CHECK_OR_FALSE(is_unsqueeze_squeeze_pair(reshape_of_ifm, reshape_as_terminal));
943
944 _matched = true;
945 return true;
946}
947
948bool InstanceNormPattern::matched()
949{
950 if (_matched)
951 return true;
952
953 // Check order is DFS
954
955 switch (_pv)
956 {
957 case PatternVersion::Version_1:
958 return match<PatternVersion::Version_1>();
959 case PatternVersion::Version_2:
960 return match<PatternVersion::Version_2>();
961 case PatternVersion::Version_3:
962 return match<PatternVersion::Version_3>();
963 case PatternVersion::Version_4:
964 return match<PatternVersion::Version_4>();
965 case PatternVersion::Version_5:
966 return match<PatternVersion::Version_5>();
967 case PatternVersion::Version_6:
968 return match<PatternVersion::Version_6>();
969 case PatternVersion::Version_7:
970 return match<PatternVersion::Version_7>();
971
972 default:
973 break;
974 }
975
976 throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
977}
978
979#undef CHECK_OR_FALSE
980
997class FuseInstanceNorm final
998{
999public:
1000 FuseInstanceNorm(const InstanceNormPattern &p) : _p(p) {}
1001
1002public:
1003 void apply(void);
1004
1005private:
1006 template <InstanceNormPattern::PatternVersion> void apply(void);
1007
1008private:
1009 void reshape_gamma_beta(void);
1010 luci::CircleInstanceNorm *create_inst_norm(loco::Graph *graph);
1011
1012private:
1013 const InstanceNormPattern &_p;
1014};
1015
1016void FuseInstanceNorm::reshape_gamma_beta()
1017{
1018 // Version 1 and 3 need to reshape
1019 {
1020 _p.const_as_gamma->rank(1);
1021 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
1022 _p.const_as_beta->rank(1);
1023 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
1024
1025 _p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
1026 _p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
1027 }
1028}
1029
1030luci::CircleInstanceNorm *FuseInstanceNorm::create_inst_norm(loco::Graph *graph)
1031{
1032 // Make Instance Norm to replace
1033 auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
1034 instance_norm->input(_p.ifm);
1035 instance_norm->gamma(_p.const_as_gamma);
1036 instance_norm->beta(_p.const_as_beta);
1037 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
1038 instance_norm->epsilon(epsilon);
1039 if (_p.add_as_terminal != nullptr)
1040 {
1041 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
1042 // NOTE unique name should be assigned in export
1043 instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
1044 }
1045 else
1046 {
1047 // VERSION_4
1048 assert(_p.div != nullptr);
1049 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
1050 instance_norm->name("FusedInstanceNorm/" + _p.div->name());
1051 }
1052
1053 return instance_norm;
1054}
1055
1056template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
1057{
1058 auto graph = _p.add_as_terminal->graph();
1059
1060 reshape_gamma_beta();
1061
1062 auto instance_norm = create_inst_norm(graph);
1063
1064 // set origin
1065 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1066 luci::get_origin(_p.mean_of_ifm),
1067 luci::get_origin(_p.sqdiff),
1068 luci::get_origin(_p.mean_as_variance),
1069 luci::get_origin(_p.add_as_variance),
1070 luci::get_origin(_p.rsqrt),
1071 luci::get_origin(_p.mul_gamma),
1072 luci::get_origin(_p.mul_as_scaled_ifm),
1073 luci::get_origin(_p.mul_as_scaled_mean),
1074 luci::get_origin(_p.sub),
1075 luci::get_origin(_p.add_as_terminal)};
1076
1077 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1078
1079 replace(_p.add_as_terminal).with(instance_norm);
1080}
1081
1082template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
1083{
1084 auto graph = _p.add_as_terminal->graph();
1085
1086 auto instance_norm = create_inst_norm(graph);
1087
1088 // set origin
1089 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1090 luci::get_origin(_p.mean_of_ifm),
1091 luci::get_origin(_p.sqdiff),
1092 luci::get_origin(_p.mean_as_variance),
1093 luci::get_origin(_p.add_as_variance),
1094 luci::get_origin(_p.pow),
1095 luci::get_origin(_p.sub),
1096 luci::get_origin(_p.div),
1097 luci::get_origin(_p.mul_gamma),
1098 luci::get_origin(_p.add_as_terminal)};
1099
1100 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1101
1102 replace(_p.add_as_terminal).with(instance_norm);
1103}
1104
1105template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
1106{
1107 auto graph = _p.add_as_terminal->graph();
1108
1109 reshape_gamma_beta();
1110
1111 auto instance_norm = create_inst_norm(graph);
1112
1113 // set origin
1114 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1115 luci::get_origin(_p.mean_of_ifm),
1116 luci::get_origin(_p.sub),
1117 luci::get_origin(_p.mean_of_ifm_2),
1118 luci::get_origin(_p.sub_2),
1119 luci::get_origin(_p.square),
1120 luci::get_origin(_p.mean_as_variance),
1121 luci::get_origin(_p.sqrt),
1122 luci::get_origin(_p.add_as_variance),
1123 luci::get_origin(_p.div),
1124 luci::get_origin(_p.mul_gamma),
1125 luci::get_origin(_p.add_as_terminal)};
1126
1127 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1128
1129 replace(_p.add_as_terminal).with(instance_norm);
1130}
1131
1132template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
1133{
1134 auto graph = _p.div->graph();
1135
1136 auto instance_norm = create_inst_norm(graph);
1137
1138 // set origin
1139 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1140 luci::get_origin(_p.mean_of_ifm),
1141 luci::get_origin(_p.sub),
1142 luci::get_origin(_p.mean_of_ifm_2),
1143 luci::get_origin(_p.sub_2),
1144 luci::get_origin(_p.square),
1145 luci::get_origin(_p.mean_as_variance),
1146 luci::get_origin(_p.sqrt),
1147 luci::get_origin(_p.add_as_variance),
1148 luci::get_origin(_p.div)};
1149
1150 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1151
1152 replace(_p.div).with(instance_norm);
1153}
1154
1155template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
1156{
1157 auto graph = _p.add_as_terminal->graph();
1158
1159 reshape_gamma_beta();
1160
1161 auto instance_norm = create_inst_norm(graph);
1162
1163 // set origin
1164 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1165 luci::get_origin(_p.mean_of_ifm),
1166 luci::get_origin(_p.sqdiff),
1167 luci::get_origin(_p.mean_as_variance),
1168 luci::get_origin(_p.add_as_variance),
1169 luci::get_origin(_p.rsqrt),
1170 luci::get_origin(_p.mul_as_scaled_ifm),
1171 luci::get_origin(_p.mul_as_scaled_mean),
1172 luci::get_origin(_p.sub),
1173 luci::get_origin(_p.add_as_terminal)};
1174
1175 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1176
1177 replace(_p.add_as_terminal).with(instance_norm);
1178}
1179
1180template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
1181{
1182 auto graph = _p.add_as_terminal->graph();
1183
1184 reshape_gamma_beta();
1185
1186 auto instance_norm = create_inst_norm(graph);
1187
1188 // set origin
1189 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1190 luci::get_origin(_p.mean_of_ifm),
1191 luci::get_origin(_p.sqdiff),
1192 luci::get_origin(_p.mean_as_variance),
1193 luci::get_origin(_p.add_as_variance),
1194 luci::get_origin(_p.rsqrt),
1195 luci::get_origin(_p.mul_as_scaled_ifm),
1196 luci::get_origin(_p.mul_as_scaled_mean),
1197 luci::get_origin(_p.sub),
1198 luci::get_origin(_p.add_as_terminal)};
1199
1200 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1201
1202 replace(_p.add_as_terminal).with(instance_norm);
1203}
1204
1205template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
1206{
1207 auto graph = _p.reshape_as_terminal->graph();
1208
1209 reshape_gamma_beta();
1210
1211 auto instance_norm = create_inst_norm(graph);
1212
1213 // set origin
1214 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1215 luci::get_origin(_p.reshape_of_ifm),
1216 luci::get_origin(_p.mean_of_ifm),
1217 luci::get_origin(_p.sub_2),
1218 luci::get_origin(_p.square),
1219 luci::get_origin(_p.mean_as_variance),
1220 luci::get_origin(_p.add_as_variance),
1221 luci::get_origin(_p.rsqrt),
1222 luci::get_origin(_p.mul_gamma),
1223 luci::get_origin(_p.neg_mean),
1224 luci::get_origin(_p.mul_as_scaled_ifm),
1225 luci::get_origin(_p.mul_as_scaled_mean),
1226 luci::get_origin(_p.add_neg_mul),
1227 luci::get_origin(_p.add_as_terminal),
1228 luci::get_origin(_p.reshape_as_terminal)};
1229
1230 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1231
1232 replace(_p.reshape_as_terminal).with(instance_norm);
1233}
1234
1235void FuseInstanceNorm::apply()
1236{
1237 assert(_p.matched());
1238
1239 switch (_p.version())
1240 {
1241 case InstanceNormPattern::PatternVersion::Version_1:
1242 apply<InstanceNormPattern::PatternVersion::Version_1>();
1243 break;
1244 case InstanceNormPattern::PatternVersion::Version_2:
1245 apply<InstanceNormPattern::PatternVersion::Version_2>();
1246 break;
1247 case InstanceNormPattern::PatternVersion::Version_3:
1248 apply<InstanceNormPattern::PatternVersion::Version_3>();
1249 break;
1250 case InstanceNormPattern::PatternVersion::Version_4:
1251 apply<InstanceNormPattern::PatternVersion::Version_4>();
1252 break;
1253 case InstanceNormPattern::PatternVersion::Version_5:
1254 apply<InstanceNormPattern::PatternVersion::Version_5>();
1255 break;
1256 case InstanceNormPattern::PatternVersion::Version_6:
1257 apply<InstanceNormPattern::PatternVersion::Version_6>();
1258 break;
1259 case InstanceNormPattern::PatternVersion::Version_7:
1260 apply<InstanceNormPattern::PatternVersion::Version_7>();
1261 break;
1262
1263 default:
1264 break;
1265 }
1266}
1267
1268} // namespace
1269
1270namespace
1271{
1272
1273class PostFusion final
1274{
1275public:
1276 PostFusion(luci::CircleInstanceNorm *inst_norm) : _inst_norm(inst_norm) {}
1277
1278private:
1279 uint32_t input_channel(void);
1280
1281 luci::CircleConst *match_const_channel(luci::CircleConst *, uint32_t);
1282 bool match_const_gamma_channel(void);
1283 bool match_const_beta_channel(void);
1284
1285public:
1286 bool process(void);
1287
1288private:
1289 luci::CircleInstanceNorm *_inst_norm = nullptr;
1290};
1291
1295uint32_t PostFusion::input_channel(void)
1296{
1297 auto input = dynamic_cast<luci::CircleNode *>(_inst_norm->input());
1298 if (input == nullptr)
1299 return 0;
1300 if (input->shape_status() != luci::ShapeStatus::VALID)
1301 return 0;
1302
1303 auto input_rank = input->rank();
1304 if (input_rank < 1)
1305 return 0;
1306
1307 if (input_rank == 3)
1308 {
1309 // use dim 1
1310 return input->dim(1).value();
1311 }
1312 // assume channel-last
1313 return input->dim(input_rank - 1).value();
1314}
1315
1319luci::CircleConst *PostFusion::match_const_channel(luci::CircleConst *input_const, uint32_t C)
1320{
1321 luci::CircleConst *new_input_const = nullptr;
1322
1323 auto input_chn = input_const->dim(0).value();
1324 if (input_chn == 1 && input_chn != C)
1325 {
1326 float value = input_const->at<loco::DataType::FLOAT32>(0);
1327 auto clone = luci::clone_node(input_const, input_const->graph());
1328
1329 new_input_const = luci::must_cast<luci::CircleConst *>(clone);
1330 new_input_const->rank(1);
1331 new_input_const->dim(0).set(C);
1332 new_input_const->size<loco::DataType::FLOAT32>(C);
1333 for (uint32_t c = 0; c < C; ++c)
1334 new_input_const->at<loco::DataType::FLOAT32>(c) = value;
1335 }
1336
1337 return new_input_const;
1338}
1339
1343bool PostFusion::match_const_gamma_channel(void)
1344{
1345 auto const_as_gamma = dynamic_cast<luci::CircleConst *>(_inst_norm->gamma());
1346 if (const_as_gamma == nullptr)
1347 return false;
1348
1349 auto C = input_channel();
1350 if (C == 0)
1351 return false;
1352
1353 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1354 if (new_const_as_gamma == nullptr)
1355 return false;
1356
1357 _inst_norm->gamma(new_const_as_gamma);
1358
1359 return true;
1360}
1361
1365bool PostFusion::match_const_beta_channel(void)
1366{
1367 auto const_as_beta = dynamic_cast<luci::CircleConst *>(_inst_norm->beta());
1368 if (const_as_beta == nullptr)
1369 return false;
1370
1371 auto C = input_channel();
1372 if (C == 0)
1373 return false;
1374
1375 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1376 if (new_const_as_beta == nullptr)
1377 return false;
1378
1379 _inst_norm->beta(new_const_as_beta);
1380
1381 return true;
1382}
1383
1384bool PostFusion::process(void)
1385{
1386 bool changed = false;
1387
1388 if (match_const_gamma_channel())
1389 changed = true;
1390 if (match_const_beta_channel())
1391 changed = true;
1392
1393 return changed;
1394}
1395
1396} // namespace
1397
1398namespace
1399{
1400
1401bool is_add_input_mul_const(luci::CircleAdd *add)
1402{
1403 luci::CircleMul *p_mul = nullptr;
1404 luci::CircleConst *p_const = nullptr;
1405
1406 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1407}
1408
1409bool is_add_input_mul_sub3d(luci::CircleAdd *add)
1410{
1411 luci::CircleMul *p_mul = nullptr;
1412 luci::CircleSub *p_sub = nullptr;
1413
1414 if (!luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1415 return false;
1416
1417 auto sub = dynamic_cast<luci::CircleSub *>(add->y());
1418 if (sub == nullptr)
1419 return false;
1420
1421 auto const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
1422 if (const_as_beta == nullptr || const_as_beta->rank() != 3)
1423 return false;
1424
1425 return true;
1426}
1427
1428bool fuse_instance_norm(luci::CircleAdd *add)
1429{
1430 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1431
1432 if (is_add_input_mul_const(add))
1433 pv = InstanceNormPattern::PatternVersion::Version_2;
1434 else if (is_add_input_mul_sub3d(add))
1435 pv = InstanceNormPattern::PatternVersion::Version_6;
1436
1437 InstanceNormPattern pattern(add, pv);
1438 if (pattern.matched())
1439 {
1440 FuseInstanceNorm fuse(pattern);
1441 fuse.apply();
1442 return true;
1443 }
1444
1445 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1446 {
1447 // if Version_1 failed, try with Version_5
1448 pv = InstanceNormPattern::PatternVersion::Version_5;
1449 InstanceNormPattern pattern(add, pv);
1450 if (pattern.matched())
1451 {
1452 FuseInstanceNorm fuse(pattern);
1453 fuse.apply();
1454 return true;
1455 }
1456 }
1457 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1458 {
1459 // if Version_2 failed, try with Version_3
1460 pv = InstanceNormPattern::PatternVersion::Version_3;
1461 InstanceNormPattern pattern(add, pv);
1462 if (pattern.matched())
1463 {
1464 FuseInstanceNorm fuse(pattern);
1465 fuse.apply();
1466 return true;
1467 }
1468 }
1469
1470 return false;
1471}
1472
1473bool fuse_instance_norm(luci::CircleDiv *div)
1474{
1475 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1476
1477 InstanceNormPattern pattern(div, pv);
1478 if (pattern.matched())
1479 {
1480 FuseInstanceNorm fuse(pattern);
1481 fuse.apply();
1482 return true;
1483 }
1484
1485 return false;
1486}
1487
1488bool fuse_instance_norm(luci::CircleReshape *reshape)
1489{
1490 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_7;
1491
1492 InstanceNormPattern pattern(reshape, pv);
1493 if (pattern.matched())
1494 {
1495 FuseInstanceNorm fuse(pattern);
1496 fuse.apply();
1497 return true;
1498 }
1499
1500 return false;
1501}
1502
1503bool post_fusion(luci::CircleInstanceNorm *inst_norm)
1504{
1505 PostFusion postfusion(inst_norm);
1506
1507 return postfusion.process();
1508}
1509
1510} // namespace
1511
1512namespace luci
1513{
1514
1516{
1517 bool changed = false;
1518
1519 // Check Version_1, Version_2, Version_3, Version_5, Version_6
1520 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1521 {
1522 auto add = dynamic_cast<luci::CircleAdd *>(node);
1523 if (not add)
1524 continue;
1525
1526 if (fuse_instance_norm(add))
1527 changed = true;
1528 }
1529
1530 // Check Version_4(from DIV) if MUL-ADD pattern is not found
1531 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1532 {
1533 auto div = dynamic_cast<luci::CircleDiv *>(node);
1534 if (not div)
1535 continue;
1536
1537 if (fuse_instance_norm(div))
1538 changed = true;
1539 }
1540
1541 // Check Version_7(from Reshape) if other versions not found
1542 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1543 {
1544 auto reshape = dynamic_cast<luci::CircleReshape *>(node);
1545 if (not reshape)
1546 continue;
1547
1548 if (fuse_instance_norm(reshape))
1549 changed = true;
1550 }
1551
1552 // Post processing of FuseInstanceNorm
1553 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1554 {
1555 auto inst_norm = dynamic_cast<luci::CircleInstanceNorm *>(node);
1556 if (not inst_norm)
1557 continue;
1558
1560 changed = true;
1561 }
1562
1563 return changed;
1564}
1565
1566} // namespace luci
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:55
Graph * graph(void)
Definition Node.h:71
void with(Node *into) const
Definition Node.cpp:66
ADD in Circle.
Definition CircleAdd.h:34
Class to build tensor data.
Definition CircleConst.h:35
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
DIV in Circle.
Definition CircleDiv.h:37
INSTANCE_NORM in Circle.
loco::Node * input(void) const
MEAN in Circle.
Definition CircleMean.h:32
bool keep_dims(void) const
Definition CircleMean.h:41
loco::Node * input(void) const
Definition CircleMean.h:34
loco::Node * reduction_indices(void) const
Definition CircleMean.h:37
MUL in Circle.
Definition CircleMul.h:34
NEG in Circle.
Definition CircleNeg.h:32
POW in Circle.
Definition CirclePow.h:32
RESHAPE in Circle.
loco::Node * tensor(void) const
RSQRT in Circle.
Definition CircleRsqrt.h:32
SQRT in Circle.
Definition CircleSqrt.h:32
SQUARE in Circle.
SQUARED_DIFFERENCE in Circle.
SUB in Circle.
Definition CircleSub.h:34
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_unsqueeze_squeeze_pair(luci::CircleReshape *begin_reshape, luci::CircleReshape *terminal_reshape)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_unsqueezed_1D(luci::CircleConst *node, uint32_t depth)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
C
Definition infer.py:52
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
T must_cast(FeatureEncoder *node)
A helper dynamic_cast that throws when failed.
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
T must_cast(loco::Node *node)
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
T square(T value)
Definition Loss.h:37
version
Definition setup.py:159
Configuration p
bool run(loco::Graph *g) final
Run the pass.