38{
39 const auto rank = node->rank();
40 std::optional<uint32_t> depth_axis;
41 for (uint32_t axis = 0; axis < rank; ++axis)
42 {
43 if (node->dim(axis).value() != 1)
44 {
45
46 if (depth_axis.has_value())
47 {
48 return false;
49 }
50 depth_axis = axis;
51 }
52 }
53 if (!depth_axis.has_value())
54 {
55 return false;
56 }
57 return node->dim(depth_axis.value()).value() == depth;
58}
59
65{
68
69
70 CHECK_OR_FALSE((begin_reshape_ifm->rank() + 1) == begin_reshape->rank());
71
72
73 for (uint32_t axis = 0; axis < begin_reshape_ifm->rank(); ++axis)
74 {
75
76 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).known() && begin_reshape->dim(axis).known());
77 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).value() == begin_reshape->dim(axis).value());
78 }
79
80 CHECK_OR_FALSE(begin_reshape->dim(begin_reshape->rank() - 1) == 1);
81
84
85 CHECK_OR_FALSE(terminal_reshape_ifm->rank() == terminal_reshape->rank() + 1);
86
87
88 CHECK_OR_FALSE(terminal_reshape_ifm->dim(begin_reshape->rank() - 1) == 1);
89
90
91 for (uint32_t axis = 0; axis < terminal_reshape->rank(); ++axis)
92 {
93
94 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).known() && terminal_reshape->dim(axis).known());
95 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).value() == terminal_reshape->dim(axis).value());
96 }
97
98 return true;
99}
100
102{
103
104
105
108 return false;
109 if (
input->rank() != 4)
110 return false;
111
112
113
114
115
116
117
119 if (not red_indices)
120 return false;
121 if (red_indices->rank() != 1)
122 return false;
123 std::set<int32_t> red_indices_set;
124 {
125
126 assert(red_indices->dtype() == loco::DataType::S32);
127 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
128 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
129 }
130 if (red_indices_set.size() != 2)
131 return false;
132 if (red_indices_set.find(1) == red_indices_set.end())
133 return false;
134 if (red_indices_set.find(2) == red_indices_set.end())
135 return false;
136
137
138
139
140
141
143}
144
146{
147
148
149
152 return false;
153 if (
input->rank() != 3)
154 return false;
155
156
157
158
159
161 if (not red_indices)
162 return false;
163 if (red_indices->rank() != 1)
164 return false;
165 std::set<int32_t> red_indices_set;
166 {
167
168 assert(red_indices->dtype() == loco::DataType::S32);
169 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
170 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
171 }
172 if (red_indices_set.size() != 1)
173 return false;
174 if (red_indices_set.find(2) == red_indices_set.end())
175 return false;
176
177
178
179
180
181
183}
184
187{
188 if (node->rank() != 1)
189 return false;
190
191 if (node->dim(0).value() != channel_size)
192 return false;
193
194 if (node->dtype() != loco::DataType::FLOAT32)
195 return false;
196
197 if (node->
size<loco::DataType::FLOAT32>() != channel_size)
198 return false;
199
200 return true;
201}
202
203
204namespace
205{
206
482class InstanceNormPattern final
483{
484public:
485 enum PatternVersion
486 {
487 Version_Unknown,
488 Version_1,
489 Version_2,
490 Version_3,
491 Version_4,
492 Version_5,
493 Version_6,
494 Version_7,
495 };
496
498 {
499 assert(candidate);
500 add_as_terminal = candidate;
501 _pv = pv;
502 }
503
505 {
506 assert(candidate);
508 _pv = pv;
509 }
510
512 {
513 assert(candidate);
514 reshape_as_terminal = candidate;
515 _pv = pv;
516 }
517
518private:
519 bool condition_common_1_5(uint32_t ifm_channel_depth);
520 bool condition_common_3_4();
521
522private:
523 template <enum PatternVersion> bool match();
524
525public:
526 bool matched();
527 bool matched() const { return _matched; }
528
529 PatternVersion
version()
const {
return _pv; }
530
531public:
532
560
561private:
562 bool _matched = false;
563 PatternVersion _pv;
564};
565
566bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
567{
570
572 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
573
574 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
575
576 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
577
579
582
588
592
593 return true;
594}
595
596bool InstanceNormPattern::condition_common_3_4()
597{
598
599 ifm = sub->x();
601
605
609
610
612 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
613
614 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
615
618
621
625
629
633 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
636
637 return true;
638}
639
640template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
641{
644
649 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
650
652
654
656
659
663 .with_commutative_args_of(mul_as_scaled_mean));
666
667 _matched = true;
668 return true;
669}
670
671template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
672{
675
678
679 ifm = sub->x();
681
685 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
686
689
691
694
697
700 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
701
704
706 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
707 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
708
709 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
710
712
715
719 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
722
723
726
727 _matched = true;
728 return true;
729}
730
731template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
732{
736
738
739 _matched = true;
740 return true;
741}
742
744{
746 const_one->dtype(loco::DataType::FLOAT32);
747 const_one->rank(1);
748 const_one->
size<loco::DataType::FLOAT32>(1);
749 const_one->at<loco::DataType::FLOAT32>(0) = value;
750 return const_one;
751}
752
753template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
754{
757
759
760 assert(const_as_gamma == nullptr);
761 assert(const_as_beta == nullptr);
762 assert(mul_gamma == nullptr);
763 assert(add_as_terminal == nullptr);
764
765
767 const_as_gamma = make_const_one(graph, 1.0f);
768 const_as_beta = make_const_one(graph, 0.0f);
769 const_as_gamma->name(
div->name() +
"/gamma");
770 const_as_beta->name(
div->name() +
"/beta");
771
772 _matched = true;
773 return true;
774}
775
776template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
777{
780
785 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
786
788
791
795 .with_commutative_args_of(mul_as_scaled_mean));
798
799
800
801 auto graph = add_as_terminal->graph();
802 const_as_gamma = make_const_one(graph, 1.0f);
803 const_as_gamma->name(add_as_terminal->name() + "/gamma");
804
805 _matched = true;
806 return true;
807}
808
809template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
810{
813
818
821
823 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
824
825 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
826
827 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
828
830
833
839
840
841 uint32_t input_channel = ifm_circle->dim(1).value();
842 uint32_t input_last_dim = ifm_circle->dim(2).value();
847 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
848 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
849
852
856 .with_commutative_args_of(mul_as_scaled_mean));
859
860
861
862 auto graph = add_as_terminal->graph();
863 const_as_gamma = make_const_one(graph, 1.0f);
864 const_as_gamma->name(add_as_terminal->name() + "/gamma");
865
866 _matched = true;
867 return true;
868}
869
870template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
871{
872 add_as_terminal =
dynamic_cast<luci::CircleAdd *
>(reshape_as_terminal->tensor());
874
876 luci::fill(&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of(add_as_terminal));
878 luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
879
880 mul_as_scaled_mean =
dynamic_cast<luci::CircleMul *
>(add_neg_mul->x());
882
885
888
890 luci::fill(&mul_gamma_should_be, &neg_should_be).with_commutative_args_of(mul_as_scaled_mean));
891
894
897
901
902 ifm = reshape_of_ifm->tensor();
905
909 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
910
914
917
920
922 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
924
925 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
926
927 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
928
931
934
937
940
942
943 _matched = true;
944 return true;
945}
946
947bool InstanceNormPattern::matched()
948{
949 if (_matched)
950 return true;
951
952
953
954 switch (_pv)
955 {
956 case PatternVersion::Version_1:
957 return match<PatternVersion::Version_1>();
958 case PatternVersion::Version_2:
959 return match<PatternVersion::Version_2>();
960 case PatternVersion::Version_3:
961 return match<PatternVersion::Version_3>();
962 case PatternVersion::Version_4:
963 return match<PatternVersion::Version_4>();
964 case PatternVersion::Version_5:
965 return match<PatternVersion::Version_5>();
966 case PatternVersion::Version_6:
967 return match<PatternVersion::Version_6>();
968 case PatternVersion::Version_7:
969 return match<PatternVersion::Version_7>();
970
971 default:
972 break;
973 }
974
975 throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
976}
977
978#undef CHECK_OR_FALSE
979
996class FuseInstanceNorm final
997{
998public:
999 FuseInstanceNorm(
const InstanceNormPattern &
p) : _p(
p) {}
1000
1001public:
1003
1004private:
1005 template <InstanceNormPattern::PatternVersion>
void apply(
void);
1006
1007private:
1008 void reshape_gamma_beta(void);
1010
1011private:
1012 const InstanceNormPattern &_p;
1013};
1014
// Rewrites the shape of the matched gamma and beta constants to rank 1, with
// the single dimension set to the number of FLOAT32 elements each one holds.
// NOTE(review): two statements of the inner scope are missing from this
// listing; confirm against the full source what else is updated here.
1015void FuseInstanceNorm::reshape_gamma_beta()
1016{
1017
1018 {
 // gamma: rank 1, dim(0) = element count
1019 _p.const_as_gamma->rank(1);
1020 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
 // beta: rank 1, dim(0) = element count
1021 _p.const_as_beta->rank(1);
1022 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
1023
1026 }
1027}
1028
1030{
1031
1033 instance_norm->
input(_p.ifm);
1034 instance_norm->gamma(_p.const_as_gamma);
1035 instance_norm->beta(_p.const_as_beta);
1036 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
1037 instance_norm->epsilon(epsilon);
1038 if (_p.add_as_terminal != nullptr)
1039 {
1040 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
1041
1042 instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
1043 }
1044 else
1045 {
1046
1047 assert(_p.div != nullptr);
1048 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
1049 instance_norm->name("FusedInstanceNorm/" + _p.div->name());
1050 }
1051
1052 return instance_norm;
1053}
1054
// Fusion step for a Version_1 match: takes the graph from
// _p.add_as_terminal, calls reshape_gamma_beta(), builds the replacement node
// with create_inst_norm() and gathers the origins of the matched nodes.
// NOTE(review): the statements that consume `origin_vec` and substitute the
// terminal node are missing from this listing; confirm against the full
// source.
1055template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
1056{
1057 auto graph = _p.add_as_terminal->graph();
1058
1059 reshape_gamma_beta();
1060
1061 auto instance_norm = create_inst_norm(graph);
1062
1063
 // Origin records of the nodes captured by the pattern.
1064 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1065 luci::get_origin(_p.mean_of_ifm),
1066 luci::get_origin(_p.sqdiff),
1067 luci::get_origin(_p.mean_as_variance),
1068 luci::get_origin(_p.add_as_variance),
1069 luci::get_origin(_p.rsqrt),
1070 luci::get_origin(_p.mul_gamma),
1071 luci::get_origin(_p.mul_as_scaled_ifm),
1072 luci::get_origin(_p.mul_as_scaled_mean),
1073 luci::get_origin(_p.sub),
1074 luci::get_origin(_p.add_as_terminal)};
1075
1077
1079}
1080
// Fusion step for a Version_2 match: takes the graph from
// _p.add_as_terminal, builds the replacement node with create_inst_norm() and
// gathers the origins of the matched nodes. Unlike some other versions, no
// reshape_gamma_beta() call is made here.
// NOTE(review): the statements that consume `origin_vec` and substitute the
// terminal node are missing from this listing; confirm against the full
// source.
1081template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
1082{
1083 auto graph = _p.add_as_terminal->graph();
1084
1085 auto instance_norm = create_inst_norm(graph);
1086
1087
 // Origin records of the nodes captured by the pattern.
1088 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1089 luci::get_origin(_p.mean_of_ifm),
1090 luci::get_origin(_p.sqdiff),
1091 luci::get_origin(_p.mean_as_variance),
1092 luci::get_origin(_p.add_as_variance),
1093 luci::get_origin(_p.pow),
1094 luci::get_origin(_p.sub),
1095 luci::get_origin(_p.div),
1096 luci::get_origin(_p.mul_gamma),
1097 luci::get_origin(_p.add_as_terminal)};
1098
1100
1102}
1103
// Fusion step for a Version_3 match: takes the graph from
// _p.add_as_terminal, calls reshape_gamma_beta(), builds the replacement node
// with create_inst_norm() and gathers the origins of the matched nodes.
// NOTE(review): the statements that consume `origin_vec` and substitute the
// terminal node are missing from this listing; confirm against the full
// source.
1104template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
1105{
1106 auto graph = _p.add_as_terminal->graph();
1107
1108 reshape_gamma_beta();
1109
1110 auto instance_norm = create_inst_norm(graph);
1111
1112
 // Origin records of the nodes captured by the pattern.
1113 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1114 luci::get_origin(_p.mean_of_ifm),
1115 luci::get_origin(_p.sub),
1116 luci::get_origin(_p.mean_of_ifm_2),
1117 luci::get_origin(_p.sub_2),
1118 luci::get_origin(_p.square),
1119 luci::get_origin(_p.mean_as_variance),
1120 luci::get_origin(_p.sqrt),
1121 luci::get_origin(_p.add_as_variance),
1122 luci::get_origin(_p.div),
1123 luci::get_origin(_p.mul_gamma),
1124 luci::get_origin(_p.add_as_terminal)};
1125
1127
1129}
1130
// Fusion step for a Version_4 match: the graph is taken from _p.div (this
// version does not read _p.add_as_terminal here), the replacement node is
// built with create_inst_norm() and the origins of the matched nodes are
// gathered.
// NOTE(review): the statements that consume `origin_vec` and substitute the
// terminal node are missing from this listing; confirm against the full
// source.
1131template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
1132{
1133 auto graph = _p.div->graph();
1134
1135 auto instance_norm = create_inst_norm(graph);
1136
1137
 // Origin records of the nodes captured by the pattern.
1138 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1139 luci::get_origin(_p.mean_of_ifm),
1140 luci::get_origin(_p.sub),
1141 luci::get_origin(_p.mean_of_ifm_2),
1142 luci::get_origin(_p.sub_2),
1143 luci::get_origin(_p.square),
1144 luci::get_origin(_p.mean_as_variance),
1145 luci::get_origin(_p.sqrt),
1146 luci::get_origin(_p.add_as_variance),
1147 luci::get_origin(_p.div)};
1148
1150
1152}
1153
// Fusion step for a Version_5 match: takes the graph from
// _p.add_as_terminal, calls reshape_gamma_beta(), builds the replacement node
// with create_inst_norm() and gathers the origins of the matched nodes.
// NOTE(review): the statements that consume `origin_vec` and substitute the
// terminal node are missing from this listing; confirm against the full
// source.
1154template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
1155{
1156 auto graph = _p.add_as_terminal->graph();
1157
1158 reshape_gamma_beta();
1159
1160 auto instance_norm = create_inst_norm(graph);
1161
1162
 // Origin records of the nodes captured by the pattern.
1163 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1164 luci::get_origin(_p.mean_of_ifm),
1165 luci::get_origin(_p.sqdiff),
1166 luci::get_origin(_p.mean_as_variance),
1167 luci::get_origin(_p.add_as_variance),
1168 luci::get_origin(_p.rsqrt),
1169 luci::get_origin(_p.mul_as_scaled_ifm),
1170 luci::get_origin(_p.mul_as_scaled_mean),
1171 luci::get_origin(_p.sub),
1172 luci::get_origin(_p.add_as_terminal)};
1173
1175
1177}
1178
// Fusion step for a Version_6 match: takes the graph from
// _p.add_as_terminal, calls reshape_gamma_beta(), builds the replacement node
// with create_inst_norm() and gathers the origins of the matched nodes.
// NOTE(review): the statements that consume `origin_vec` and substitute the
// terminal node are missing from this listing; confirm against the full
// source.
1179template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
1180{
1181 auto graph = _p.add_as_terminal->graph();
1182
1183 reshape_gamma_beta();
1184
1185 auto instance_norm = create_inst_norm(graph);
1186
1187
 // Origin records of the nodes captured by the pattern.
1188 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1189 luci::get_origin(_p.mean_of_ifm),
1190 luci::get_origin(_p.sqdiff),
1191 luci::get_origin(_p.mean_as_variance),
1192 luci::get_origin(_p.add_as_variance),
1193 luci::get_origin(_p.rsqrt),
1194 luci::get_origin(_p.mul_as_scaled_ifm),
1195 luci::get_origin(_p.mul_as_scaled_mean),
1196 luci::get_origin(_p.sub),
1197 luci::get_origin(_p.add_as_terminal)};
1198
1200
1202}
1203
// Fusion step for a Version_7 match: the graph is taken from
// _p.reshape_as_terminal, reshape_gamma_beta() is called, the replacement
// node is built with create_inst_norm(), the origins of the matched nodes are
// gathered, and finally the terminal reshape is replaced by the new node.
// NOTE(review): the statement that consumes `origin_vec` is missing from this
// listing; confirm against the full source.
1204template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
1205{
1206 auto graph = _p.reshape_as_terminal->graph();
1207
1208 reshape_gamma_beta();
1209
1210 auto instance_norm = create_inst_norm(graph);
1211
1212
 // Origin records of the nodes captured by the pattern, including both
 // the leading and the terminal reshape.
1213 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1214 luci::get_origin(_p.reshape_of_ifm),
1215 luci::get_origin(_p.mean_of_ifm),
1216 luci::get_origin(_p.sub_2),
1217 luci::get_origin(_p.square),
1218 luci::get_origin(_p.mean_as_variance),
1219 luci::get_origin(_p.add_as_variance),
1220 luci::get_origin(_p.rsqrt),
1221 luci::get_origin(_p.mul_gamma),
1222 luci::get_origin(_p.neg_mean),
1223 luci::get_origin(_p.mul_as_scaled_ifm),
1224 luci::get_origin(_p.mul_as_scaled_mean),
1225 luci::get_origin(_p.add_neg_mul),
1226 luci::get_origin(_p.add_as_terminal),
1227 luci::get_origin(_p.reshape_as_terminal)};
1228
1230
 // Substitute the terminal reshape with the fused node.
1231 replace(_p.reshape_as_terminal).
with(instance_norm);
1232}
1233
1234void FuseInstanceNorm::apply()
1235{
1236 assert(_p.matched());
1237
1238 switch (_p.version())
1239 {
1240 case InstanceNormPattern::PatternVersion::Version_1:
1241 apply<InstanceNormPattern::PatternVersion::Version_1>();
1242 break;
1243 case InstanceNormPattern::PatternVersion::Version_2:
1244 apply<InstanceNormPattern::PatternVersion::Version_2>();
1245 break;
1246 case InstanceNormPattern::PatternVersion::Version_3:
1247 apply<InstanceNormPattern::PatternVersion::Version_3>();
1248 break;
1249 case InstanceNormPattern::PatternVersion::Version_4:
1250 apply<InstanceNormPattern::PatternVersion::Version_4>();
1251 break;
1252 case InstanceNormPattern::PatternVersion::Version_5:
1253 apply<InstanceNormPattern::PatternVersion::Version_5>();
1254 break;
1255 case InstanceNormPattern::PatternVersion::Version_6:
1256 apply<InstanceNormPattern::PatternVersion::Version_6>();
1257 break;
1258 case InstanceNormPattern::PatternVersion::Version_7:
1259 apply<InstanceNormPattern::PatternVersion::Version_7>();
1260 break;
1261
1262 default:
1263 break;
1264 }
1265}
1266
1267}
1268
1269namespace
1270{
1271
1272class PostFusion final
1273{
1274public:
1276
1277private:
1278 uint32_t input_channel(void);
1279
1281 bool match_const_gamma_channel(void);
1282 bool match_const_beta_channel(void);
1283
1284public:
1286
1287private:
1289};
1290
// Returns a dimension of `input` used as the channel count, or 0 when it
// cannot be determined: for a rank-3 input the value of dim(1), otherwise the
// value of the last dimension.
// NOTE(review): the line that declares `input` and the condition guarding the
// second `return 0` are missing from this listing; confirm against the full
// source.
1294uint32_t PostFusion::input_channel(void)
1295{
1297 if (input == nullptr)
1298 return 0;
 // NOTE(review): the guard for the following return is not visible here.
1300 return 0;
1301
1302 auto input_rank =
input->rank();
1303 if (input_rank < 1)
1304 return 0;
1305
1306 if (input_rank == 3)
1307 {
1308
 // rank 3: axis 1 is used (presumably a channel-first layout; verify)
1309 return input->dim(1).value();
1310 }
1311
 // every other rank: the last axis is used
1312 return input->dim(input_rank - 1).value();
1313}
1314
1319{
1321
1322 auto input_chn = input_const->dim(0).value();
1323 if (input_chn == 1 && input_chn != C)
1324 {
1325 float value = input_const->
at<loco::DataType::FLOAT32>(0);
1327
1329 new_input_const->rank(1);
1330 new_input_const->dim(0).set(C);
1331 new_input_const->
size<loco::DataType::FLOAT32>(
C);
1332 for (uint32_t c = 0; c <
C; ++c)
1333 new_input_const->
at<loco::DataType::FLOAT32>(c) = value;
1334 }
1335
1336 return new_input_const;
1337}
1338
// Replaces the gamma input of _inst_norm with the constant returned by
// match_const_channel(const_as_gamma, C), where C comes from input_channel().
// Returns false without touching the node when const_as_gamma is null, when
// C is 0, or when match_const_channel() yields nullptr; returns true after
// the gamma input has been set.
// NOTE(review): the line that declares `const_as_gamma` is missing from this
// listing; confirm against the full source.
1342bool PostFusion::match_const_gamma_channel(void)
1343{
1345 if (const_as_gamma == nullptr)
1346 return false;
1347
 // 0 is input_channel()'s "unknown" result
1348 auto C = input_channel();
1349 if (C == 0)
1350 return false;
1351
1352 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1353 if (new_const_as_gamma == nullptr)
1354 return false;
1355
1356 _inst_norm->gamma(new_const_as_gamma);
1357
1358 return true;
1359}
1360
// Replaces the beta input of _inst_norm with the constant returned by
// match_const_channel(const_as_beta, C), where C comes from input_channel().
// Returns false without touching the node when const_as_beta is null, when
// C is 0, or when match_const_channel() yields nullptr; returns true after
// the beta input has been set.
// NOTE(review): the line that declares `const_as_beta` is missing from this
// listing; confirm against the full source.
1364bool PostFusion::match_const_beta_channel(void)
1365{
1367 if (const_as_beta == nullptr)
1368 return false;
1369
 // 0 is input_channel()'s "unknown" result
1370 auto C = input_channel();
1371 if (C == 0)
1372 return false;
1373
1374 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1375 if (new_const_as_beta == nullptr)
1376 return false;
1377
1378 _inst_norm->beta(new_const_as_beta);
1379
1380 return true;
1381}
1382
1383bool PostFusion::process(void)
1384{
1385 bool changed = false;
1386
1387 if (match_const_gamma_channel())
1388 changed = true;
1389 if (match_const_beta_channel())
1390 changed = true;
1391
1392 return changed;
1393}
1394
1395}
1396
1397namespace
1398{
1399
1401{
1404
1405 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1406}
1407
1409{
1412
1413 if (!
luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1414 return false;
1415
1417 if (sub == nullptr)
1418 return false;
1419
1421 if (const_as_beta == nullptr || const_as_beta->rank() != 3)
1422 return false;
1423
1424 return true;
1425}
1426
1428{
1429 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1430
1431 if (is_add_input_mul_const(add))
1432 pv = InstanceNormPattern::PatternVersion::Version_2;
1433 else if (is_add_input_mul_sub3d(add))
1434 pv = InstanceNormPattern::PatternVersion::Version_6;
1435
1436 InstanceNormPattern pattern(add, pv);
1437 if (pattern.matched())
1438 {
1439 FuseInstanceNorm fuse(pattern);
1440 fuse.apply();
1441 return true;
1442 }
1443
1444 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1445 {
1446
1447 pv = InstanceNormPattern::PatternVersion::Version_5;
1448 InstanceNormPattern pattern(add, pv);
1449 if (pattern.matched())
1450 {
1451 FuseInstanceNorm fuse(pattern);
1452 fuse.apply();
1453 return true;
1454 }
1455 }
1456 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1457 {
1458
1459 pv = InstanceNormPattern::PatternVersion::Version_3;
1460 InstanceNormPattern pattern(add, pv);
1461 if (pattern.matched())
1462 {
1463 FuseInstanceNorm fuse(pattern);
1464 fuse.apply();
1465 return true;
1466 }
1467 }
1468
1469 return false;
1470}
1471
1473{
1474 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1475
1476 InstanceNormPattern pattern(div, pv);
1477 if (pattern.matched())
1478 {
1479 FuseInstanceNorm fuse(pattern);
1480 fuse.apply();
1481 return true;
1482 }
1483
1484 return false;
1485}
1486
1488{
1489 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_7;
1490
1491 InstanceNormPattern pattern(reshape, pv);
1492 if (pattern.matched())
1493 {
1494 FuseInstanceNorm fuse(pattern);
1495 fuse.apply();
1496 return true;
1497 }
1498
1499 return false;
1500}
1501
1503{
1504 PostFusion postfusion(inst_norm);
1505
1506 return postfusion.process();
1507}
1508
1509}
1510
1512{
1513
1515{
1516 bool changed = false;
1517
1518
1520 {
1523 continue;
1524
1526 changed = true;
1527 }
1528
1529
1531 {
1534 continue;
1535
1537 changed = true;
1538 }
1539
1540
1542 {
1545 continue;
1546
1548 changed = true;
1549 }
1550
1551
1553 {
1556 continue;
1557
1559 changed = true;
1560 }
1561
1562 return changed;
1563}
1564
1565}
Logical unit of computation.
void with(Node *into) const
Class to build tensor data.
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
loco::Node * input(void) const
bool keep_dims(void) const
loco::Node * input(void) const
loco::Node * reduction_indices(void) const
loco::Node * tensor(void) const
SQUARED_DIFFERENCE in Circle.
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_unsqueeze_squeeze_pair(luci::CircleReshape *begin_reshape, luci::CircleReshape *terminal_reshape)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_unsqueezed_1D(luci::CircleConst *node, uint32_t depth)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
T must_cast(FeatureEncoder *node)
A helper dynamic_cast that throws when failed.
std::vector< Node * > output_nodes(Graph *)
Subst< SubstQualifier::Default > replace(Node *node)
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
T must_cast(loco::Node *node)
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
bool run(loco::Graph *g) final
Run the pass.