ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseInstanceNormPass.cpp File Reference
#include "luci/Pass/FuseInstanceNormPass.h"
#include "helpers/NodeFiller.h"
#include "FuseInstanceNormPassInternal.h"
#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/CircleNodeClone.h>
#include <cassert>
#include <set>
#include <optional>

Go to the source code of this file.

Namespaces

namespace  luci
 

Macros

#define CHECK_OR_FALSE(condition)
 

Functions

bool is_unsqueezed_1D (luci::CircleConst *node, uint32_t depth)
 
bool is_unsqueeze_squeeze_pair (luci::CircleReshape *begin_reshape, luci::CircleReshape *terminal_reshape)
 
bool is_instance_mean_v1 (luci::CircleMean *mean)
 
bool is_instance_mean_v2 (luci::CircleMean *mean)
 
bool is_1D_float32_const (const luci::CircleConst *node, uint32_t channel_size)
 

Macro Definition Documentation

◆ CHECK_OR_FALSE

#define CHECK_OR_FALSE (   condition)
Value:
if (not(condition)) \
return false;

Definition at line 32 of file FuseInstanceNormPass.cpp.

// Body of is_unsqueezed_1D(luci::CircleConst *node, uint32_t depth).
//
// Returns true when exactly one axis of `node` has an extent other than 1 and
// that extent equals `depth`. Returns false when every axis is 1 (including
// rank 0) or when more than one axis differs from 1.
//
// NOTE(review): dim(axis).value() is read without a known() check; presumably
// fine because the node is a constant with a fully known shape - confirm.
38{
39 const auto rank = node->rank();
// axis of the single non-1 dimension, once found
40 std::optional<uint32_t> depth_axis;
41 for (uint32_t axis = 0; axis < rank; ++axis)
42 {
43 if (node->dim(axis).value() != 1)
44 {
45 // only one axis can be other than 1
46 if (depth_axis.has_value())
47 {
48 return false;
49 }
50 depth_axis = axis;
51 }
52 }
// no axis differs from 1: there is no depth axis to compare against
53 if (!depth_axis.has_value())
54 {
55 return false;
56 }
57 return node->dim(depth_axis.value()).value() == depth;
58}
59
64 luci::CircleReshape *terminal_reshape)
65{
66 auto const begin_reshape_ifm = dynamic_cast<luci::CircleNode *>(begin_reshape->tensor());
67 CHECK_OR_FALSE(begin_reshape_ifm);
68
69 // check last axis
70 CHECK_OR_FALSE((begin_reshape_ifm->rank() + 1) == begin_reshape->rank());
71
72 // check unchanged part of begin_shape
73 for (uint32_t axis = 0; axis < begin_reshape_ifm->rank(); ++axis)
74 {
75 // skip dynamic cases
76 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).known() && begin_reshape->dim(axis).known());
77 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).value() == begin_reshape->dim(axis).value());
78 }
79 // check last axis
80 CHECK_OR_FALSE(begin_reshape->dim(begin_reshape->rank() - 1) == 1);
81
82 auto const terminal_reshape_ifm = dynamic_cast<luci::CircleNode *>(terminal_reshape->tensor());
83 CHECK_OR_FALSE(terminal_reshape_ifm);
84
85 CHECK_OR_FALSE(terminal_reshape_ifm->rank() == terminal_reshape->rank() + 1);
86
87 // check last axis
88 CHECK_OR_FALSE(terminal_reshape_ifm->dim(begin_reshape->rank() - 1) == 1);
89
90 // check unchanged part of terminal_reshape
91 for (uint32_t axis = 0; axis < terminal_reshape->rank(); ++axis)
92 {
93 // skip dynamic cases
94 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).known() && terminal_reshape->dim(axis).known());
95 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).value() == terminal_reshape->dim(axis).value());
96 }
97
98 return true;
99}
100
102{
103 //
104 // CHECK 1) input is rank 4
105 //
107 if (input->shape_status() != luci::ShapeStatus::VALID)
108 return false;
109 if (input->rank() != 4)
110 return false;
111
112 //
113 // CHECK 2) 'reduction indices' is CircleConst of value [1,2], that is HW of NHWC
114 //
115 // TODO Support equivalent case, like [-3,-2]
116 // TODO Support non-Const case?
117 // TODO What if input is NCHW format in Circle?
118 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
119 if (not red_indices)
120 return false;
121 if (red_indices->rank() != 1)
122 return false;
123 std::set<int32_t> red_indices_set;
124 {
125 // TODO Currently only support S32, support other types
126 assert(red_indices->dtype() == loco::DataType::S32);
127 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
128 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
129 }
130 if (red_indices_set.size() != 2)
131 return false;
132 if (red_indices_set.find(1) == red_indices_set.end())
133 return false;
134 if (red_indices_set.find(2) == red_indices_set.end())
135 return false;
136
137 //
138 // CHECK 3) keep_dims == true (?)
139 //
140 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
141 // TODO Check this fact, and if true, return true regardless of keep_dims
142 return mean->keep_dims();
143}
144
146{
147 //
148 // CHECK 1) input is rank 3
149 //
151 if (input->shape_status() != luci::ShapeStatus::VALID)
152 return false;
153 if (input->rank() != 3)
154 return false;
155
156 //
157 // CHECK 2) 'reduction indices' is CircleConst of value [2], that is last dim of rank 3
158 //
159 // TODO Support non-Const case?
160 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
161 if (not red_indices)
162 return false;
163 if (red_indices->rank() != 1)
164 return false;
165 std::set<int32_t> red_indices_set;
166 {
167 // TODO Currently only support S32, support other types
168 assert(red_indices->dtype() == loco::DataType::S32);
169 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
170 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
171 }
172 if (red_indices_set.size() != 1)
173 return false;
174 if (red_indices_set.find(2) == red_indices_set.end())
175 return false;
176
177 //
178 // CHECK 3) keep_dims == true (?)
179 //
180 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
181 // TODO Check this fact, and if true, return true regardless of keep_dims
182 return mean->keep_dims();
183}
184
186bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
187{
188 if (node->rank() != 1)
189 return false;
190
191 if (node->dim(0).value() != channel_size)
192 return false;
193
194 if (node->dtype() != loco::DataType::FLOAT32)
195 return false;
196
197 if (node->size<loco::DataType::FLOAT32>() != channel_size)
198 return false;
199
200 return true;
201}
202
203// Helper to fuse Instance Norm
204namespace
205{
206
482class InstanceNormPattern final
483{
484public:
485 enum PatternVersion
486 {
487 Version_Unknown,
488 Version_1,
489 Version_2,
490 Version_3,
491 Version_4,
492 Version_5,
493 Version_6, // For only 3D I/O
494 Version_7,
495 };
496
497 InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv)
498 {
499 assert(candidate);
500 add_as_terminal = candidate;
501 _pv = pv;
502 }
503
504 InstanceNormPattern(luci::CircleDiv *candidate, PatternVersion pv)
505 {
506 assert(candidate);
507 div = candidate;
508 _pv = pv;
509 }
510
511 InstanceNormPattern(luci::CircleReshape *candidate, PatternVersion pv)
512 {
513 assert(candidate);
514 reshape_as_terminal = candidate;
515 _pv = pv;
516 }
517
518private:
519 bool condition_common_1_5(uint32_t ifm_channel_depth);
520 bool condition_common_3_4();
521
522private:
523 template <enum PatternVersion> bool match();
524
525public:
526 bool matched();
527 bool matched() const { return _matched; }
528
529 PatternVersion version() const { return _pv; }
530
531public:
532 // Context
533 loco::Node *ifm = nullptr;
534 luci::CircleReshape *reshape_of_ifm = nullptr;
535 luci::CircleMean *mean_of_ifm = nullptr;
536 luci::CircleMean *mean_of_ifm_2 = nullptr;
537 luci::CircleMean *mean_of_reshape = nullptr;
538 luci::CircleSquaredDifference *sqdiff = nullptr;
539 luci::CircleSquare *square = nullptr;
540 luci::CircleMean *mean_as_variance = nullptr;
541 luci::CircleConst *const_as_epsilon = nullptr;
542 luci::CircleAdd *add_as_variance = nullptr;
543 luci::CircleAdd *add_neg_mul = nullptr;
544 luci::CircleRsqrt *rsqrt = nullptr;
545 luci::CircleConst *const_as_gamma = nullptr;
546 luci::CircleMul *mul_gamma = nullptr;
547 luci::CircleMul *mul_as_scaled_ifm = nullptr;
548 luci::CircleMul *mul_as_scaled_mean = nullptr;
549 luci::CircleMul *mul_as_scaled_reshape = nullptr;
550 luci::CircleConst *const_as_beta = nullptr;
551 luci::CircleSub *sub = nullptr;
552 luci::CircleSub *sub_2 = nullptr;
553 luci::CircleAdd *add_as_terminal = nullptr;
554 luci::CirclePow *pow = nullptr;
555 luci::CircleSqrt *sqrt = nullptr;
556 luci::CircleDiv *div = nullptr;
557 luci::CircleConst *reshape_terminal_target_shape = nullptr;
558 luci::CircleReshape *reshape_as_terminal = nullptr;
559 luci::CircleNeg *neg_mean = nullptr;
560
561private:
562 bool _matched = false;
563 PatternVersion _pv;
564};
565
566bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
567{
568 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
569 CHECK_OR_FALSE(add_as_variance);
570
572 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
573
574 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
575 // TODO Support regarding broadcast
576 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
577
578 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
579
580 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
581 CHECK_OR_FALSE(sqdiff);
582
583 loco::Node *ifm_should_be = nullptr;
584 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
585 CHECK_OR_FALSE(ifm == ifm_should_be);
587 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
588
589 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
590 CHECK_OR_FALSE(const_as_beta);
591 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_beta, ifm_channel_depth));
592
593 return true;
594}
595
596bool InstanceNormPattern::condition_common_3_4()
597{
598 // check left sub
599 ifm = sub->x();
600 CHECK_OR_FALSE(ifm);
601
603 CHECK_OR_FALSE(ifm_node->rank() == 4);
604 CHECK_OR_FALSE(ifm_node->dim(3).known());
605
606 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
607 CHECK_OR_FALSE(mean_of_ifm);
608 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
609
610 // continue search from add_as_variance
611 CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
612 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
613 // TODO Support regarding broadcast
614 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
615
616 mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
617 CHECK_OR_FALSE(mean_as_variance);
618
619 square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
620 CHECK_OR_FALSE(square);
621
622 sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
623 CHECK_OR_FALSE(sub_2);
624 CHECK_OR_FALSE(ifm == sub_2->x());
625
626 mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
627 CHECK_OR_FALSE(mean_of_ifm_2);
628 CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
629
630 loco::Node *ifm_should_be = nullptr;
631 luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
633 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
634 CHECK_OR_FALSE(ifm == ifm_should_be);
635 CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
636
637 return true;
638}
639
// Matcher for pattern Version_1, seeded from add_as_terminal.
//
// Walks upward from the terminal Add and records each node in the context
// members. Returns false at the first check that fails (CHECK_OR_FALSE);
// on success sets _matched and returns true.
//
// presumably luci::fill(&a, &b).with_commutative_args_of(node) binds the two
// operands of `node` to `a` and `b` in whichever order their types fit and
// reports failure otherwise - confirm against helpers/NodeFiller.h.
640template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
641{
// terminal Add takes a Mul (scaled ifm) and a Sub
642 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
// scaled ifm = ifm * mul_gamma
643 CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
644
// ifm must be rank 4 with a valid shape and a known last dimension,
// which is used as the channel depth below
645 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
646 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
647 CHECK_OR_FALSE(ifm_circle->rank() == 4);
648 CHECK_OR_FALSE(ifm_circle->dim(3).known());
649 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
650
// mul_gamma = rsqrt * const gamma
651 CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
652
// gamma may have any rank as long as its only non-1 axis equals the channel depth
653 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_gamma, ifm_channel_depth));
654
// checks shared with Version_5; uses rsqrt and sub set above
655 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
656
657 luci::CircleMul *mul_gamma_should_be = nullptr;
658 luci::CircleMean *mean_of_ifm_should_be = nullptr;
659
// second operand of Sub must be mul_gamma * mean_of_ifm, re-using the very
// same mul_gamma and mean_of_ifm nodes recorded earlier
660 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
661 CHECK_OR_FALSE(mul_as_scaled_mean);
662 CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
663 .with_commutative_args_of(mul_as_scaled_mean));
664 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
665 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
666
667 _matched = true;
668 return true;
669}
670
671template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
672{
673 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
674 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
675
676 sub = dynamic_cast<luci::CircleSub *>(div->x());
677 CHECK_OR_FALSE(sub);
678
679 ifm = sub->x();
680 CHECK_OR_FALSE(ifm);
681
683 CHECK_OR_FALSE(ifm_node->rank() == 4);
684 CHECK_OR_FALSE(ifm_node->dim(3).known());
685 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
686
687 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
688 CHECK_OR_FALSE(mean_of_ifm);
689
690 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
691
692 pow = dynamic_cast<luci::CirclePow *>(div->y());
693 CHECK_OR_FALSE(pow);
694
695 add_as_variance = dynamic_cast<luci::CircleAdd *>(pow->x());
696 CHECK_OR_FALSE(add_as_variance);
697
698 luci::CircleConst *zero_point_five = dynamic_cast<luci::CircleConst *>(pow->y());
699 CHECK_OR_FALSE(zero_point_five);
700 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
701 // TODO Support regarding broadcast
702 CHECK_OR_FALSE(zero_point_five->size<loco::DataType::FLOAT32>() == 1);
703 CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
704
706 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
707 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
708 // TODO Support regarding broadcast
709 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
710
711 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
712
713 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
714 CHECK_OR_FALSE(sqdiff);
715
716 loco::Node *ifm_should_be = nullptr;
717 luci::CircleMean *mean_of_ifm_should_be = nullptr;
719 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
720 CHECK_OR_FALSE(ifm == ifm_should_be);
721 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
722
723 // Check for channel size
724 CHECK_OR_FALSE(is_1D_float32_const(const_as_gamma, ifm_channel_depth));
725 CHECK_OR_FALSE(is_1D_float32_const(const_as_beta, ifm_channel_depth));
726
727 _matched = true;
728 return true;
729}
730
// Matcher for pattern Version_3, seeded from add_as_terminal.
//
// Binds terminal Add -> (mul_gamma, const beta), mul_gamma -> (div, const gamma)
// and div -> (sub, add_as_variance), then defers the remaining checks to
// condition_common_3_4(). Sets _matched and returns true on success; returns
// false at the first failing check.
//
// NOTE(review): the operands of Div are bound through the "commutative" helper
// although division is not commutative; presumably the distinct operand types
// (Sub vs. Add) are what pins the order - confirm.
731template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
732{
733 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
734 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
735 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
736
737 CHECK_OR_FALSE(condition_common_3_4());
738
739 _matched = true;
740 return true;
741}
742
743luci::CircleConst *make_const_one(loco::Graph *graph, float value)
744{
745 auto const_one = graph->nodes()->create<luci::CircleConst>();
746 const_one->dtype(loco::DataType::FLOAT32);
747 const_one->rank(1);
748 const_one->size<loco::DataType::FLOAT32>(1);
749 const_one->at<loco::DataType::FLOAT32>(0) = value;
750 return const_one;
751}
752
// Matcher for pattern Version_4, seeded from `div` (the terminal node).
//
// Same sub-graph as Version_3 but without the trailing gamma Mul and beta Add:
// binds div -> (sub, add_as_variance) and defers to condition_common_3_4().
// On success, gamma (1.0) and beta (0.0) constants are created in div's graph
// so that the fused node always has both inputs, _matched is set and true is
// returned. Returns false at the first failing check.
//
// NOTE(review): the constants are created as a side effect of matching, i.e.
// before any fusion is applied.
753template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
754{
755 CHECK_OR_FALSE(div);
756 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
757
758 CHECK_OR_FALSE(condition_common_3_4());
759
// this version has no gamma/beta nodes in the graph; nothing must have set them
760 assert(const_as_gamma == nullptr);
761 assert(const_as_beta == nullptr);
762 assert(mul_gamma == nullptr);
763 assert(add_as_terminal == nullptr);
764
765 // create 1.0 gamma and 0.0 beta
766 auto graph = div->graph();
767 const_as_gamma = make_const_one(graph, 1.0f);
768 const_as_beta = make_const_one(graph, 0.0f);
// names are derived from the terminal Div
769 const_as_gamma->name(div->name() + "/gamma");
770 const_as_beta->name(div->name() + "/beta");
771
772 _matched = true;
773 return true;
774}
775
// Matcher for pattern Version_5, seeded from add_as_terminal.
//
// Like Version_1 but without a gamma multiplication: ifm is multiplied by
// rsqrt directly, and the mean is scaled by that same rsqrt. On success a 1.0
// gamma constant is created in the terminal Add's graph, _matched is set and
// true is returned. Returns false at the first failing check.
776template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
777{
// terminal Add takes a Mul (scaled ifm) and a Sub; scaled ifm = ifm * rsqrt
778 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
779 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
780
// ifm must be rank 4 with a valid shape and a known last dimension,
// which is used as the channel depth below
781 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
782 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
783 CHECK_OR_FALSE(ifm_circle->rank() == 4);
784 CHECK_OR_FALSE(ifm_circle->dim(3).known());
785 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
786
// checks shared with Version_1; uses rsqrt and sub set above
787 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
788
789 luci::CircleRsqrt *rsqrt_should_be = nullptr;
790 luci::CircleMean *mean_of_ifm_should_be = nullptr;
791
// second operand of Sub must be rsqrt * mean_of_ifm, re-using the very same
// rsqrt and mean_of_ifm nodes recorded earlier
792 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
793 CHECK_OR_FALSE(mul_as_scaled_mean);
794 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
795 .with_commutative_args_of(mul_as_scaled_mean));
796 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
797 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
798
799 // mul_gamma is absent
800 // const_as_gamma assume to be 1.0
801 auto graph = add_as_terminal->graph();
802 const_as_gamma = make_const_one(graph, 1.0f);
803 const_as_gamma->name(add_as_terminal->name() + "/gamma");
804
805 _matched = true;
806 return true;
807}
808
809template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
810{
811 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
812 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
813
814 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
815 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
816 CHECK_OR_FALSE(ifm_circle->rank() == 3);
817 CHECK_OR_FALSE((ifm_circle->dim(1).known()));
818
819 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
820 CHECK_OR_FALSE(add_as_variance);
821
823 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
824
825 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
826 // TODO Support regarding broadcast
827 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
828
829 CHECK_OR_FALSE(is_instance_mean_v2(mean_as_variance));
830
831 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
832 CHECK_OR_FALSE(sqdiff);
833
834 loco::Node *ifm_should_be = nullptr;
835 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
836 CHECK_OR_FALSE(ifm == ifm_should_be);
838 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
839
840 // const_as_beta must have shape '1 x channel x (1 or input last dimension)'
841 uint32_t input_channel = ifm_circle->dim(1).value();
842 uint32_t input_last_dim = ifm_circle->dim(2).value();
843 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
844 CHECK_OR_FALSE(const_as_beta);
845 CHECK_OR_FALSE(const_as_beta->rank() == 3);
847 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
848 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
849
850 luci::CircleRsqrt *rsqrt_should_be = nullptr;
851 luci::CircleMean *mean_of_ifm_should_be = nullptr;
852
853 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
854 CHECK_OR_FALSE(mul_as_scaled_mean);
855 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
856 .with_commutative_args_of(mul_as_scaled_mean));
857 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
858 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
859
860 // mul_gamma is absent
861 // const_as_gamma assume to be 1.0
862 auto graph = add_as_terminal->graph();
863 const_as_gamma = make_const_one(graph, 1.0f);
864 const_as_gamma->name(add_as_terminal->name() + "/gamma");
865
866 _matched = true;
867 return true;
868}
869
870template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
871{
872 add_as_terminal = dynamic_cast<luci::CircleAdd *>(reshape_as_terminal->tensor());
873 CHECK_OR_FALSE(add_as_terminal);
874
876 luci::fill(&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of(add_as_terminal));
878 luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
879
880 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(add_neg_mul->x());
881 CHECK_OR_FALSE(mul_as_scaled_mean);
882
883 neg_mean = dynamic_cast<luci::CircleNeg *>(mul_as_scaled_mean->x());
884 CHECK_OR_FALSE(neg_mean);
885
886 luci::CircleMul *mul_gamma_should_be = nullptr;
887 luci::CircleNeg *neg_should_be = nullptr;
888
890 luci::fill(&mul_gamma_should_be, &neg_should_be).with_commutative_args_of(mul_as_scaled_mean));
891
892 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
893 CHECK_OR_FALSE(neg_mean == neg_should_be);
894
895 mean_of_ifm = dynamic_cast<luci::CircleMean *>(neg_mean->x());
896 CHECK_OR_FALSE(mean_of_ifm);
897
898 luci::CircleReshape *reshape_of_ifm_should_be = nullptr;
899 reshape_of_ifm_should_be = dynamic_cast<luci::CircleReshape *>(mean_of_ifm->input());
900 CHECK_OR_FALSE(reshape_of_ifm_should_be == reshape_of_ifm);
901
902 ifm = reshape_of_ifm->tensor();
903 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
904 CHECK_OR_FALSE(ifm_circle);
905
906 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
907 CHECK_OR_FALSE(ifm_circle->rank() == 4);
908 CHECK_OR_FALSE(ifm_circle->dim(3).known());
909 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
910
911 const_as_beta = dynamic_cast<luci::CircleConst *>(add_neg_mul->y());
912 CHECK_OR_FALSE(const_as_beta);
913 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_beta, ifm_channel_depth));
914
915 CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
916 CHECK_OR_FALSE(is_unsqueezed_1D(const_as_gamma, ifm_channel_depth));
917
918 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
919 CHECK_OR_FALSE(add_as_variance);
920
922 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
923 CHECK_OR_FALSE(mean_as_variance);
924
925 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
926 // TODO Support regarding broadcast
927 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
928
929 square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
930 CHECK_OR_FALSE(square);
931
932 sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
933 CHECK_OR_FALSE(sub_2);
934
935 auto mean_of_ifm_should_be = dynamic_cast<luci::CircleMean *>(sub_2->y());
936 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
937
938 auto reshape_of_ifm_should_be_2 = dynamic_cast<luci::CircleReshape *>(sub_2->x());
939 CHECK_OR_FALSE(reshape_of_ifm_should_be_2 == reshape_of_ifm);
940
941 CHECK_OR_FALSE(is_unsqueeze_squeeze_pair(reshape_of_ifm, reshape_as_terminal));
942
943 _matched = true;
944 return true;
945}
946
947bool InstanceNormPattern::matched()
948{
949 if (_matched)
950 return true;
951
952 // Check order is DFS
953
954 switch (_pv)
955 {
956 case PatternVersion::Version_1:
957 return match<PatternVersion::Version_1>();
958 case PatternVersion::Version_2:
959 return match<PatternVersion::Version_2>();
960 case PatternVersion::Version_3:
961 return match<PatternVersion::Version_3>();
962 case PatternVersion::Version_4:
963 return match<PatternVersion::Version_4>();
964 case PatternVersion::Version_5:
965 return match<PatternVersion::Version_5>();
966 case PatternVersion::Version_6:
967 return match<PatternVersion::Version_6>();
968 case PatternVersion::Version_7:
969 return match<PatternVersion::Version_7>();
970
971 default:
972 break;
973 }
974
975 throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
976}
977
978#undef CHECK_OR_FALSE
979
996class FuseInstanceNorm final
997{
998public:
999 FuseInstanceNorm(const InstanceNormPattern &p) : _p(p) {}
1000
1001public:
1002 void apply(void);
1003
1004private:
1005 template <InstanceNormPattern::PatternVersion> void apply(void);
1006
1007private:
1008 void reshape_gamma_beta(void);
1009 luci::CircleInstanceNorm *create_inst_norm(loco::Graph *graph);
1010
1011private:
1012 const InstanceNormPattern &_p;
1013};
1014
1015void FuseInstanceNorm::reshape_gamma_beta()
1016{
1017 // Version 1 and 3 need to reshape
1018 {
1019 _p.const_as_gamma->rank(1);
1020 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
1021 _p.const_as_beta->rank(1);
1022 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
1023
1024 _p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
1025 _p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
1026 }
1027}
1028
1029luci::CircleInstanceNorm *FuseInstanceNorm::create_inst_norm(loco::Graph *graph)
1030{
1031 // Make Instance Norm to replace
1032 auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
1033 instance_norm->input(_p.ifm);
1034 instance_norm->gamma(_p.const_as_gamma);
1035 instance_norm->beta(_p.const_as_beta);
1036 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
1037 instance_norm->epsilon(epsilon);
1038 if (_p.add_as_terminal != nullptr)
1039 {
1040 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
1041 // NOTE unique name should be assigned in export
1042 instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
1043 }
1044 else
1045 {
1046 // VERSION_4
1047 assert(_p.div != nullptr);
1048 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
1049 instance_norm->name("FusedInstanceNorm/" + _p.div->name());
1050 }
1051
1052 return instance_norm;
1053}
1054
1055template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
1056{
1057 auto graph = _p.add_as_terminal->graph();
1058
1059 reshape_gamma_beta();
1060
1061 auto instance_norm = create_inst_norm(graph);
1062
1063 // set origin
1064 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1065 luci::get_origin(_p.mean_of_ifm),
1066 luci::get_origin(_p.sqdiff),
1067 luci::get_origin(_p.mean_as_variance),
1068 luci::get_origin(_p.add_as_variance),
1069 luci::get_origin(_p.rsqrt),
1070 luci::get_origin(_p.mul_gamma),
1071 luci::get_origin(_p.mul_as_scaled_ifm),
1072 luci::get_origin(_p.mul_as_scaled_mean),
1073 luci::get_origin(_p.sub),
1074 luci::get_origin(_p.add_as_terminal)};
1075
1076 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1077
1078 replace(_p.add_as_terminal).with(instance_norm);
1079}
1080
1081template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
1082{
1083 auto graph = _p.add_as_terminal->graph();
1084
1085 auto instance_norm = create_inst_norm(graph);
1086
1087 // set origin
1088 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1089 luci::get_origin(_p.mean_of_ifm),
1090 luci::get_origin(_p.sqdiff),
1091 luci::get_origin(_p.mean_as_variance),
1092 luci::get_origin(_p.add_as_variance),
1093 luci::get_origin(_p.pow),
1094 luci::get_origin(_p.sub),
1095 luci::get_origin(_p.div),
1096 luci::get_origin(_p.mul_gamma),
1097 luci::get_origin(_p.add_as_terminal)};
1098
1099 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1100
1101 replace(_p.add_as_terminal).with(instance_norm);
1102}
1103
1104template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
1105{
1106 auto graph = _p.add_as_terminal->graph();
1107
1108 reshape_gamma_beta();
1109
1110 auto instance_norm = create_inst_norm(graph);
1111
1112 // set origin
1113 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1114 luci::get_origin(_p.mean_of_ifm),
1115 luci::get_origin(_p.sub),
1116 luci::get_origin(_p.mean_of_ifm_2),
1117 luci::get_origin(_p.sub_2),
1118 luci::get_origin(_p.square),
1119 luci::get_origin(_p.mean_as_variance),
1120 luci::get_origin(_p.sqrt),
1121 luci::get_origin(_p.add_as_variance),
1122 luci::get_origin(_p.div),
1123 luci::get_origin(_p.mul_gamma),
1124 luci::get_origin(_p.add_as_terminal)};
1125
1126 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1127
1128 replace(_p.add_as_terminal).with(instance_norm);
1129}
1130
1131template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
1132{
1133 auto graph = _p.div->graph();
1134
1135 auto instance_norm = create_inst_norm(graph);
1136
1137 // set origin
1138 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1139 luci::get_origin(_p.mean_of_ifm),
1140 luci::get_origin(_p.sub),
1141 luci::get_origin(_p.mean_of_ifm_2),
1142 luci::get_origin(_p.sub_2),
1143 luci::get_origin(_p.square),
1144 luci::get_origin(_p.mean_as_variance),
1145 luci::get_origin(_p.sqrt),
1146 luci::get_origin(_p.add_as_variance),
1147 luci::get_origin(_p.div)};
1148
1149 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1150
1151 replace(_p.div).with(instance_norm);
1152}
1153
1154template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
1155{
1156 auto graph = _p.add_as_terminal->graph();
1157
1158 reshape_gamma_beta();
1159
1160 auto instance_norm = create_inst_norm(graph);
1161
1162 // set origin
1163 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1164 luci::get_origin(_p.mean_of_ifm),
1165 luci::get_origin(_p.sqdiff),
1166 luci::get_origin(_p.mean_as_variance),
1167 luci::get_origin(_p.add_as_variance),
1168 luci::get_origin(_p.rsqrt),
1169 luci::get_origin(_p.mul_as_scaled_ifm),
1170 luci::get_origin(_p.mul_as_scaled_mean),
1171 luci::get_origin(_p.sub),
1172 luci::get_origin(_p.add_as_terminal)};
1173
1174 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1175
1176 replace(_p.add_as_terminal).with(instance_norm);
1177}
1178
1179template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
1180{
1181 auto graph = _p.add_as_terminal->graph();
1182
1183 reshape_gamma_beta();
1184
1185 auto instance_norm = create_inst_norm(graph);
1186
1187 // set origin
1188 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1189 luci::get_origin(_p.mean_of_ifm),
1190 luci::get_origin(_p.sqdiff),
1191 luci::get_origin(_p.mean_as_variance),
1192 luci::get_origin(_p.add_as_variance),
1193 luci::get_origin(_p.rsqrt),
1194 luci::get_origin(_p.mul_as_scaled_ifm),
1195 luci::get_origin(_p.mul_as_scaled_mean),
1196 luci::get_origin(_p.sub),
1197 luci::get_origin(_p.add_as_terminal)};
1198
1199 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1200
1201 replace(_p.add_as_terminal).with(instance_norm);
1202}
1203
1204template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
1205{
1206 auto graph = _p.reshape_as_terminal->graph();
1207
1208 reshape_gamma_beta();
1209
1210 auto instance_norm = create_inst_norm(graph);
1211
1212 // set origin
1213 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1214 luci::get_origin(_p.reshape_of_ifm),
1215 luci::get_origin(_p.mean_of_ifm),
1216 luci::get_origin(_p.sub_2),
1217 luci::get_origin(_p.square),
1218 luci::get_origin(_p.mean_as_variance),
1219 luci::get_origin(_p.add_as_variance),
1220 luci::get_origin(_p.rsqrt),
1221 luci::get_origin(_p.mul_gamma),
1222 luci::get_origin(_p.neg_mean),
1223 luci::get_origin(_p.mul_as_scaled_ifm),
1224 luci::get_origin(_p.mul_as_scaled_mean),
1225 luci::get_origin(_p.add_neg_mul),
1226 luci::get_origin(_p.add_as_terminal),
1227 luci::get_origin(_p.reshape_as_terminal)};
1228
1229 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1230
1231 replace(_p.reshape_as_terminal).with(instance_norm);
1232}
1233
1234void FuseInstanceNorm::apply()
1235{
1236 assert(_p.matched());
1237
1238 switch (_p.version())
1239 {
1240 case InstanceNormPattern::PatternVersion::Version_1:
1241 apply<InstanceNormPattern::PatternVersion::Version_1>();
1242 break;
1243 case InstanceNormPattern::PatternVersion::Version_2:
1244 apply<InstanceNormPattern::PatternVersion::Version_2>();
1245 break;
1246 case InstanceNormPattern::PatternVersion::Version_3:
1247 apply<InstanceNormPattern::PatternVersion::Version_3>();
1248 break;
1249 case InstanceNormPattern::PatternVersion::Version_4:
1250 apply<InstanceNormPattern::PatternVersion::Version_4>();
1251 break;
1252 case InstanceNormPattern::PatternVersion::Version_5:
1253 apply<InstanceNormPattern::PatternVersion::Version_5>();
1254 break;
1255 case InstanceNormPattern::PatternVersion::Version_6:
1256 apply<InstanceNormPattern::PatternVersion::Version_6>();
1257 break;
1258 case InstanceNormPattern::PatternVersion::Version_7:
1259 apply<InstanceNormPattern::PatternVersion::Version_7>();
1260 break;
1261
1262 default:
1263 break;
1264 }
1265}
1266
1267} // namespace
1268
1269namespace
1270{
1271
1272class PostFusion final
1273{
1274public:
1275 PostFusion(luci::CircleInstanceNorm *inst_norm) : _inst_norm(inst_norm) {}
1276
1277private:
1278 uint32_t input_channel(void);
1279
1280 luci::CircleConst *match_const_channel(luci::CircleConst *, uint32_t);
1281 bool match_const_gamma_channel(void);
1282 bool match_const_beta_channel(void);
1283
1284public:
1285 bool process(void);
1286
1287private:
1288 luci::CircleInstanceNorm *_inst_norm = nullptr;
1289};
1290
1294uint32_t PostFusion::input_channel(void)
1295{
1296 auto input = dynamic_cast<luci::CircleNode *>(_inst_norm->input());
1297 if (input == nullptr)
1298 return 0;
1299 if (input->shape_status() != luci::ShapeStatus::VALID)
1300 return 0;
1301
1302 auto input_rank = input->rank();
1303 if (input_rank < 1)
1304 return 0;
1305
1306 if (input_rank == 3)
1307 {
1308 // use dim 1
1309 return input->dim(1).value();
1310 }
1311 // assume channel-last
1312 return input->dim(input_rank - 1).value();
1313}
1314
1318luci::CircleConst *PostFusion::match_const_channel(luci::CircleConst *input_const, uint32_t C)
1319{
1320 luci::CircleConst *new_input_const = nullptr;
1321
1322 auto input_chn = input_const->dim(0).value();
1323 if (input_chn == 1 && input_chn != C)
1324 {
1325 float value = input_const->at<loco::DataType::FLOAT32>(0);
1326 auto clone = luci::clone_node(input_const, input_const->graph());
1327
1328 new_input_const = luci::must_cast<luci::CircleConst *>(clone);
1329 new_input_const->rank(1);
1330 new_input_const->dim(0).set(C);
1331 new_input_const->size<loco::DataType::FLOAT32>(C);
1332 for (uint32_t c = 0; c < C; ++c)
1333 new_input_const->at<loco::DataType::FLOAT32>(c) = value;
1334 }
1335
1336 return new_input_const;
1337}
1338
1342bool PostFusion::match_const_gamma_channel(void)
1343{
1344 auto const_as_gamma = dynamic_cast<luci::CircleConst *>(_inst_norm->gamma());
1345 if (const_as_gamma == nullptr)
1346 return false;
1347
1348 auto C = input_channel();
1349 if (C == 0)
1350 return false;
1351
1352 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1353 if (new_const_as_gamma == nullptr)
1354 return false;
1355
1356 _inst_norm->gamma(new_const_as_gamma);
1357
1358 return true;
1359}
1360
1364bool PostFusion::match_const_beta_channel(void)
1365{
1366 auto const_as_beta = dynamic_cast<luci::CircleConst *>(_inst_norm->beta());
1367 if (const_as_beta == nullptr)
1368 return false;
1369
1370 auto C = input_channel();
1371 if (C == 0)
1372 return false;
1373
1374 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1375 if (new_const_as_beta == nullptr)
1376 return false;
1377
1378 _inst_norm->beta(new_const_as_beta);
1379
1380 return true;
1381}
1382
1383bool PostFusion::process(void)
1384{
1385 bool changed = false;
1386
1387 if (match_const_gamma_channel())
1388 changed = true;
1389 if (match_const_beta_channel())
1390 changed = true;
1391
1392 return changed;
1393}
1394
1395} // namespace
1396
1397namespace
1398{
1399
1400bool is_add_input_mul_const(luci::CircleAdd *add)
1401{
1402 luci::CircleMul *p_mul = nullptr;
1403 luci::CircleConst *p_const = nullptr;
1404
1405 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1406}
1407
1408bool is_add_input_mul_sub3d(luci::CircleAdd *add)
1409{
1410 luci::CircleMul *p_mul = nullptr;
1411 luci::CircleSub *p_sub = nullptr;
1412
1413 if (!luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1414 return false;
1415
1416 auto sub = dynamic_cast<luci::CircleSub *>(add->y());
1417 if (sub == nullptr)
1418 return false;
1419
1420 auto const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
1421 if (const_as_beta == nullptr || const_as_beta->rank() != 3)
1422 return false;
1423
1424 return true;
1425}
1426
1427bool fuse_instance_norm(luci::CircleAdd *add)
1428{
1429 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1430
1431 if (is_add_input_mul_const(add))
1432 pv = InstanceNormPattern::PatternVersion::Version_2;
1433 else if (is_add_input_mul_sub3d(add))
1434 pv = InstanceNormPattern::PatternVersion::Version_6;
1435
1436 InstanceNormPattern pattern(add, pv);
1437 if (pattern.matched())
1438 {
1439 FuseInstanceNorm fuse(pattern);
1440 fuse.apply();
1441 return true;
1442 }
1443
1444 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1445 {
1446 // if Version_1 failed, try with Version_5
1447 pv = InstanceNormPattern::PatternVersion::Version_5;
1448 InstanceNormPattern pattern(add, pv);
1449 if (pattern.matched())
1450 {
1451 FuseInstanceNorm fuse(pattern);
1452 fuse.apply();
1453 return true;
1454 }
1455 }
1456 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1457 {
1458 // if Version_2 failed, try with Version_3
1459 pv = InstanceNormPattern::PatternVersion::Version_3;
1460 InstanceNormPattern pattern(add, pv);
1461 if (pattern.matched())
1462 {
1463 FuseInstanceNorm fuse(pattern);
1464 fuse.apply();
1465 return true;
1466 }
1467 }
1468
1469 return false;
1470}
1471
1472bool fuse_instance_norm(luci::CircleDiv *div)
1473{
1474 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1475
1476 InstanceNormPattern pattern(div, pv);
1477 if (pattern.matched())
1478 {
1479 FuseInstanceNorm fuse(pattern);
1480 fuse.apply();
1481 return true;
1482 }
1483
1484 return false;
1485}
1486
1487bool fuse_instance_norm(luci::CircleReshape *reshape)
1488{
1489 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_7;
1490
1491 InstanceNormPattern pattern(reshape, pv);
1492 if (pattern.matched())
1493 {
1494 FuseInstanceNorm fuse(pattern);
1495 fuse.apply();
1496 return true;
1497 }
1498
1499 return false;
1500}
1501
1502bool post_fusion(luci::CircleInstanceNorm *inst_norm)
1503{
1504 PostFusion postfusion(inst_norm);
1505
1506 return postfusion.process();
1507}
1508
1509} // namespace
1510
1511namespace luci
1512{
1513
1515{
1516 bool changed = false;
1517
1518 // Check Version_1, Version_2, Version_3, Version_5, Version_6
1519 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1520 {
1521 auto add = dynamic_cast<luci::CircleAdd *>(node);
1522 if (not add)
1523 continue;
1524
1525 if (fuse_instance_norm(add))
1526 changed = true;
1527 }
1528
1529 // Check Version_4(from DIV) if MUL-ADD pattern is not found
1530 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1531 {
1532 auto div = dynamic_cast<luci::CircleDiv *>(node);
1533 if (not div)
1534 continue;
1535
1536 if (fuse_instance_norm(div))
1537 changed = true;
1538 }
1539
1540 // Check Version_7(from Reshape) if other versions not found
1541 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1542 {
1543 auto reshape = dynamic_cast<luci::CircleReshape *>(node);
1544 if (not reshape)
1545 continue;
1546
1547 if (fuse_instance_norm(reshape))
1548 changed = true;
1549 }
1550
1551 // Post processing of FuseInstanceNorm
1552 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1553 {
1554 auto inst_norm = dynamic_cast<luci::CircleInstanceNorm *>(node);
1555 if (not inst_norm)
1556 continue;
1557
1559 changed = true;
1560 }
1561
1562 return changed;
1563}
1564
1565} // namespace luci
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:55
Graph * graph(void)
Definition Node.h:71
void with(Node *into) const
Definition Node.cpp:66
ADD in Circle.
Definition CircleAdd.h:34
Class to build tensor data.
Definition CircleConst.h:35
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
DIV in Circle.
Definition CircleDiv.h:37
INSTANCE_NORM in Circle.
loco::Node * input(void) const
MEAN in Circle.
Definition CircleMean.h:32
bool keep_dims(void) const
Definition CircleMean.h:41
loco::Node * input(void) const
Definition CircleMean.h:34
loco::Node * reduction_indices(void) const
Definition CircleMean.h:37
MUL in Circle.
Definition CircleMul.h:34
NEG in Circle.
Definition CircleNeg.h:32
POW in Circle.
Definition CirclePow.h:32
RESHAPE in Circle.
loco::Node * tensor(void) const
RSQRT in Circle.
Definition CircleRsqrt.h:32
SQRT in Circle.
Definition CircleSqrt.h:32
SQUARE in Circle.
SQUARED_DIFFERENCE in Circle.
SUB in Circle.
Definition CircleSub.h:34
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_unsqueeze_squeeze_pair(luci::CircleReshape *begin_reshape, luci::CircleReshape *terminal_reshape)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_unsqueezed_1D(luci::CircleConst *node, uint32_t depth)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
C
Definition infer.py:52
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
T must_cast(FeatureEncoder *node)
A helper dynamic_cast that throws when failed.
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
T must_cast(loco::Node *node)
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
T square(T value)
Definition Loss.h:37
version
Definition setup.py:159
Configuration p
bool run(loco::Graph *g) final
Run the pass.

Function Documentation

◆ is_1D_float32_const()

bool is_1D_float32_const ( const luci::CircleConst node,
uint32_t  channel_size 
)
Returns
true when node is a rank-1 FLOAT32 constant whose single dimension and element count both equal channel_size

Definition at line 187 of file FuseInstanceNormPass.cpp.

188{
189 if (node->rank() != 1)
190 return false;
191
192 if (node->dim(0).value() != channel_size)
193 return false;
194
195 if (node->dtype() != loco::DataType::FLOAT32)
196 return false;
197
198 if (node->size<loco::DataType::FLOAT32>() != channel_size)
199 return false;
200
201 return true;
202}

References luci::CircleConst::size().

◆ is_instance_mean_v1()

bool is_instance_mean_v1 ( luci::CircleMean mean)

Definition at line 102 of file FuseInstanceNormPass.cpp.

103{
104 //
105 // CHECK 1) input is rank 4
106 //
108 if (input->shape_status() != luci::ShapeStatus::VALID)
109 return false;
110 if (input->rank() != 4)
111 return false;
112
113 //
114 // CHECK 2) 'reduction indices' is CircleConst of value [1,2], that is HW of NHWC
115 //
116 // TODO Support equivalent case, like [-3,-2]
117 // TODO Support non-Const case?
118 // TODO What if input is NCHW format in Circle?
119 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
120 if (not red_indices)
121 return false;
122 if (red_indices->rank() != 1)
123 return false;
124 std::set<int32_t> red_indices_set;
125 {
126 // TODO Currently only support S32, support other types
127 assert(red_indices->dtype() == loco::DataType::S32);
128 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
129 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
130 }
131 if (red_indices_set.size() != 2)
132 return false;
133 if (red_indices_set.find(1) == red_indices_set.end())
134 return false;
135 if (red_indices_set.find(2) == red_indices_set.end())
136 return false;
137
138 //
139 // CHECK 3) keep_dims == true (?)
140 //
141 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
142 // TODO Check this fact, and if true, return true regardless of keep_dims
143 return mean->keep_dims();
144}

References luci::CircleMean::input(), luci::CircleMean::keep_dims(), loco::must_cast(), luci::CircleMean::reduction_indices(), and luci::VALID.

◆ is_instance_mean_v2()

bool is_instance_mean_v2 ( luci::CircleMean mean)

Definition at line 146 of file FuseInstanceNormPass.cpp.

147{
148 //
149 // CHECK 1) input is rank 3
150 //
152 if (input->shape_status() != luci::ShapeStatus::VALID)
153 return false;
154 if (input->rank() != 3)
155 return false;
156
157 //
158 // CHECK 2) 'reduction indices' is CircleConst of value [2], that is last dim of rank 3
159 //
160 // TODO Support non-Const case?
161 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
162 if (not red_indices)
163 return false;
164 if (red_indices->rank() != 1)
165 return false;
166 std::set<int32_t> red_indices_set;
167 {
168 // TODO Currently only support S32, support other types
169 assert(red_indices->dtype() == loco::DataType::S32);
170 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
171 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
172 }
173 if (red_indices_set.size() != 1)
174 return false;
175 if (red_indices_set.find(2) == red_indices_set.end())
176 return false;
177
178 //
179 // CHECK 3) keep_dims == true (?)
180 //
181 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
182 // TODO Check this fact, and if true, return true regardless of keep_dims
183 return mean->keep_dims();
184}

References luci::CircleMean::input(), luci::CircleMean::keep_dims(), loco::must_cast(), luci::CircleMean::reduction_indices(), and luci::VALID.

◆ is_unsqueeze_squeeze_pair()

bool is_unsqueeze_squeeze_pair ( luci::CircleReshape begin_reshape,
luci::CircleReshape terminal_reshape 
)
Returns
true if the provided begin_reshape Reshape op adds 1 dimension and terminal_reshape Reshape op removes it (the result is neutral for further processing)

Definition at line 64 of file FuseInstanceNormPass.cpp.

66{
67 auto const begin_reshape_ifm = dynamic_cast<luci::CircleNode *>(begin_reshape->tensor());
68 CHECK_OR_FALSE(begin_reshape_ifm);
69
70 // check last axis
71 CHECK_OR_FALSE((begin_reshape_ifm->rank() + 1) == begin_reshape->rank());
72
73 // check unchanged part of begin_shape
74 for (uint32_t axis = 0; axis < begin_reshape_ifm->rank(); ++axis)
75 {
76 // skip dynamic cases
77 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).known() && begin_reshape->dim(axis).known());
78 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).value() == begin_reshape->dim(axis).value());
79 }
80 // check last axis
81 CHECK_OR_FALSE(begin_reshape->dim(begin_reshape->rank() - 1) == 1);
82
83 auto const terminal_reshape_ifm = dynamic_cast<luci::CircleNode *>(terminal_reshape->tensor());
84 CHECK_OR_FALSE(terminal_reshape_ifm);
85
86 CHECK_OR_FALSE(terminal_reshape_ifm->rank() == terminal_reshape->rank() + 1);
87
88 // check last axis
89 CHECK_OR_FALSE(terminal_reshape_ifm->dim(begin_reshape->rank() - 1) == 1);
90
91 // check unchanged part of terminal_reshape
92 for (uint32_t axis = 0; axis < terminal_reshape->rank(); ++axis)
93 {
94 // skip dynamic cases
95 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).known() && terminal_reshape->dim(axis).known());
96 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).value() == terminal_reshape->dim(axis).value());
97 }
98
99 return true;
100}

References CHECK_OR_FALSE, and luci::CircleReshape::tensor().

◆ is_unsqueezed_1D()

bool is_unsqueezed_1D ( luci::CircleConst node,
uint32_t  depth 
)
Returns
true when exactly one dim of node's shape is other than 1 and that dim equals depth (like '1 x .. x 1 x depth' or '1 x .. x depth x 1')

Definition at line 38 of file FuseInstanceNormPass.cpp.

39{
40 const auto rank = node->rank();
41 std::optional<uint32_t> depth_axis;
42 for (uint32_t axis = 0; axis < rank; ++axis)
43 {
44 if (node->dim(axis).value() != 1)
45 {
46 // only one axis can be other than 1
47 if (depth_axis.has_value())
48 {
49 return false;
50 }
51 depth_axis = axis;
52 }
53 }
54 if (!depth_axis.has_value())
55 {
56 return false;
57 }
58 return node->dim(depth_axis.value()).value() == depth;
59}