450{
453
455 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
456
457 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
458
459 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
460
462
465
471
475
476 return true;
477}
478
479bool InstanceNormPattern::condition_common_3_4()
480{
481
482 ifm = sub->x();
484
488
492
493
495 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
496
497 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
498
501
504
508
512
516 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
519
520 return true;
521}
522
523template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
524{
527
528 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
532 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
533
535
537
539
542
546 .with_commutative_args_of(mul_as_scaled_mean));
549
550 _matched = true;
551 return true;
552}
553
554template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
555{
558
561
562 ifm = sub->x();
564
568 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
569
572
574
577
580
583 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
584
587
589 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
590 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
591
592 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
593
595
598
602 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
605
606
609
610 _matched = true;
611 return true;
612}
613
614template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
615{
619
621
622 _matched = true;
623 return true;
624}
625
627{
629 const_one->dtype(loco::DataType::FLOAT32);
630 const_one->rank(1);
631 const_one->
size<loco::DataType::FLOAT32>(1);
632 const_one->at<loco::DataType::FLOAT32>(0) = value;
633 return const_one;
634}
635
// Matcher specialization for PatternVersion::Version_4.
// On the visible path it assigns constant gamma/beta nodes obtained from
// make_const_one(), names them after `div`, and records the match in _matched.
// NOTE(review): several original lines of this function are not visible in
// this listing (including where `graph` is obtained) — confirm against the
// full source before relying on this description.
636template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
637{
640
642
// Debug-only checks: these members are expected to still be unset here.
643 assert(const_as_gamma == nullptr);
644 assert(const_as_beta == nullptr);
645 assert(mul_gamma == nullptr);
646 assert(add_as_terminal == nullptr);
647
648
// gamma is requested with value 1.0f and beta with value 0.0f.
650 const_as_gamma = make_const_one(graph, 1.0f);
651 const_as_beta = make_const_one(graph, 0.0f);
// Node names are derived from the name of `div`.
652 const_as_gamma->name(
div->name() +
"/gamma");
653 const_as_beta->name(
div->name() +
"/beta");
654
// Remember the successful match so matched() can return early next time.
655 _matched = true;
656 return true;
657}
658
659template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
660{
663
664 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
668 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
669
671
674
678 .with_commutative_args_of(mul_as_scaled_mean));
681
682
683
684 auto graph = add_as_terminal->graph();
685 const_as_gamma = make_const_one(graph, 1.0f);
686 const_as_gamma->name(add_as_terminal->name() + "/gamma");
687
688 _matched = true;
689 return true;
690}
691
692template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
693{
696
697 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
701
704
706 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
707
708 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
709
710 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
711
713
716
722
723
724 uint32_t input_channel = ifm_circle->dim(1).value();
725 uint32_t input_last_dim = ifm_circle->dim(2).value();
730 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
731 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
732
735
739 .with_commutative_args_of(mul_as_scaled_mean));
742
743
744
745 auto graph = add_as_terminal->graph();
746 const_as_gamma = make_const_one(graph, 1.0f);
747 const_as_gamma->name(add_as_terminal->name() + "/gamma");
748
749 _matched = true;
750 return true;
751}
752
753bool InstanceNormPattern::matched()
754{
755 if (_matched)
756 return true;
757
758
759
760 switch (_pv)
761 {
762 case PatternVersion::Version_1:
763 return match<PatternVersion::Version_1>();
764 case PatternVersion::Version_2:
765 return match<PatternVersion::Version_2>();
766 case PatternVersion::Version_3:
767 return match<PatternVersion::Version_3>();
768 case PatternVersion::Version_4:
769 return match<PatternVersion::Version_4>();
770 case PatternVersion::Version_5:
771 return match<PatternVersion::Version_5>();
772 case PatternVersion::Version_6:
773 return match<PatternVersion::Version_6>();
774
775 default:
776 break;
777 }
778
779 throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
780}
781
782#undef CHECK_OR_FALSE
783
// Helper that holds a reference to an InstanceNormPattern and declares the
// version-specific apply<>() routines that act on it.
// NOTE(review): some member declarations of this class are not visible in
// this listing; only the visible ones are annotated.
800class FuseInstanceNorm final
801{
802public:
// Stores a reference only; `p` must outlive this object.
803 FuseInstanceNorm(const InstanceNormPattern &p) : _p(p) {}
804
805public:
807
808private:
// One specialization per InstanceNormPattern::PatternVersion.
809 template <InstanceNormPattern::PatternVersion>
void apply(
void);
810
811private:
// Sets gamma/beta constants of the pattern to rank 1 (see definition).
812 void reshape_gamma_beta(void);
814
815private:
// The pattern this object operates on (not owned).
816 const InstanceNormPattern &_p;
817};
818
// Sets the pattern's gamma and beta constants to rank 1, with dim(0) taken
// from size<FLOAT32>() of the respective constant (presumably its element
// count — TODO confirm).
// NOTE(review): two original lines at the end of the inner scope are not
// visible in this listing.
819void FuseInstanceNorm::reshape_gamma_beta()
820{
821
822 {
823 _p.const_as_gamma->rank(1);
824 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
825 _p.const_as_beta->rank(1);
826 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
827
830 }
831}
832
834{
835
837 instance_norm->
input(_p.ifm);
838 instance_norm->gamma(_p.const_as_gamma);
839 instance_norm->beta(_p.const_as_beta);
840 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
841 instance_norm->epsilon(epsilon);
842 if (_p.add_as_terminal != nullptr)
843 {
844 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
845
846 instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
847 }
848 else
849 {
850
851 assert(_p.div != nullptr);
852 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
853 instance_norm->name("FusedInstanceNorm/" + _p.div->name());
854 }
855
856 return instance_norm;
857}
858
// Version_1 fusion step: reshapes gamma/beta, creates the instance-norm node
// in the graph of the terminal add, and collects the origins of the matched
// nodes into origin_vec.
// NOTE(review): the statements that consume origin_vec and instance_norm are
// not visible in this listing.
859template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
860{
861 auto graph = _p.add_as_terminal->graph();
862
863 reshape_gamma_beta();
864
865 auto instance_norm = create_inst_norm(graph);
866
867
// Origins of every node taking part in the Version_1 pattern.
868 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
869 luci::get_origin(_p.mean_of_ifm),
870 luci::get_origin(_p.sqdiff),
871 luci::get_origin(_p.mean_as_variance),
872 luci::get_origin(_p.add_as_variance),
873 luci::get_origin(_p.rsqrt),
874 luci::get_origin(_p.mul_gamma),
875 luci::get_origin(_p.mul_as_scaled_ifm),
876 luci::get_origin(_p.mul_as_scaled_mean),
877 luci::get_origin(_p.sub),
878 luci::get_origin(_p.add_as_terminal)};
879
881
883}
884
// Version_2 fusion step: creates the instance-norm node in the graph of the
// terminal add and collects the origins of the matched nodes into origin_vec.
// Unlike Version_1, reshape_gamma_beta() is not called here.
// NOTE(review): the statements that consume origin_vec and instance_norm are
// not visible in this listing.
885template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
886{
887 auto graph = _p.add_as_terminal->graph();
888
889 auto instance_norm = create_inst_norm(graph);
890
891
// Origins of every node taking part in the Version_2 pattern.
892 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
893 luci::get_origin(_p.mean_of_ifm),
894 luci::get_origin(_p.sqdiff),
895 luci::get_origin(_p.mean_as_variance),
896 luci::get_origin(_p.add_as_variance),
897 luci::get_origin(_p.pow),
898 luci::get_origin(_p.sub),
899 luci::get_origin(_p.div),
900 luci::get_origin(_p.mul_gamma),
901 luci::get_origin(_p.add_as_terminal)};
902
904
906}
907
// Version_3 fusion step: reshapes gamma/beta, creates the instance-norm node
// in the graph of the terminal add, and collects the origins of the matched
// nodes into origin_vec.
// NOTE(review): the statements that consume origin_vec and instance_norm are
// not visible in this listing.
908template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
909{
910 auto graph = _p.add_as_terminal->graph();
911
912 reshape_gamma_beta();
913
914 auto instance_norm = create_inst_norm(graph);
915
916
// Origins of every node taking part in the Version_3 pattern.
917 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
918 luci::get_origin(_p.mean_of_ifm),
919 luci::get_origin(_p.sub),
920 luci::get_origin(_p.mean_of_ifm_2),
921 luci::get_origin(_p.sub_2),
922 luci::get_origin(_p.square),
923 luci::get_origin(_p.mean_as_variance),
924 luci::get_origin(_p.sqrt),
925 luci::get_origin(_p.add_as_variance),
926 luci::get_origin(_p.div),
927 luci::get_origin(_p.mul_gamma),
928 luci::get_origin(_p.add_as_terminal)};
929
931
933}
934
// Version_4 fusion step: creates the instance-norm node in the graph of `div`
// (this version takes the graph from div rather than add_as_terminal) and
// collects the origins of the matched nodes into origin_vec.
// NOTE(review): the statements that consume origin_vec and instance_norm are
// not visible in this listing.
935template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
936{
937 auto graph = _p.div->graph();
938
939 auto instance_norm = create_inst_norm(graph);
940
941
// Origins of every node taking part in the Version_4 pattern.
942 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
943 luci::get_origin(_p.mean_of_ifm),
944 luci::get_origin(_p.sub),
945 luci::get_origin(_p.mean_of_ifm_2),
946 luci::get_origin(_p.sub_2),
947 luci::get_origin(_p.square),
948 luci::get_origin(_p.mean_as_variance),
949 luci::get_origin(_p.sqrt),
950 luci::get_origin(_p.add_as_variance),
951 luci::get_origin(_p.div)};
952
954
956}
957
// Version_5 fusion step: reshapes gamma/beta, creates the instance-norm node
// in the graph of the terminal add, and collects the origins of the matched
// nodes into origin_vec.
// NOTE(review): the statements that consume origin_vec and instance_norm are
// not visible in this listing.
958template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
959{
960 auto graph = _p.add_as_terminal->graph();
961
962 reshape_gamma_beta();
963
964 auto instance_norm = create_inst_norm(graph);
965
966
// Origins of every node taking part in the Version_5 pattern.
967 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
968 luci::get_origin(_p.mean_of_ifm),
969 luci::get_origin(_p.sqdiff),
970 luci::get_origin(_p.mean_as_variance),
971 luci::get_origin(_p.add_as_variance),
972 luci::get_origin(_p.rsqrt),
973 luci::get_origin(_p.mul_as_scaled_ifm),
974 luci::get_origin(_p.mul_as_scaled_mean),
975 luci::get_origin(_p.sub),
976 luci::get_origin(_p.add_as_terminal)};
977
979
981}
982
// Version_6 fusion step: reshapes gamma/beta, creates the instance-norm node
// in the graph of the terminal add, and collects the origins of the matched
// nodes into origin_vec.
// NOTE(review): the statements that consume origin_vec and instance_norm are
// not visible in this listing.
983template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
984{
985 auto graph = _p.add_as_terminal->graph();
986
987 reshape_gamma_beta();
988
989 auto instance_norm = create_inst_norm(graph);
990
991
// Origins of every node taking part in the Version_6 pattern.
992 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
993 luci::get_origin(_p.mean_of_ifm),
994 luci::get_origin(_p.sqdiff),
995 luci::get_origin(_p.mean_as_variance),
996 luci::get_origin(_p.add_as_variance),
997 luci::get_origin(_p.rsqrt),
998 luci::get_origin(_p.mul_as_scaled_ifm),
999 luci::get_origin(_p.mul_as_scaled_mean),
1000 luci::get_origin(_p.sub),
1001 luci::get_origin(_p.add_as_terminal)};
1002
1004
1006}
1007
1008void FuseInstanceNorm::apply()
1009{
1010 assert(_p.matched());
1011
1012 switch (_p.version())
1013 {
1014 case InstanceNormPattern::PatternVersion::Version_1:
1015 apply<InstanceNormPattern::PatternVersion::Version_1>();
1016 break;
1017 case InstanceNormPattern::PatternVersion::Version_2:
1018 apply<InstanceNormPattern::PatternVersion::Version_2>();
1019 break;
1020 case InstanceNormPattern::PatternVersion::Version_3:
1021 apply<InstanceNormPattern::PatternVersion::Version_3>();
1022 break;
1023 case InstanceNormPattern::PatternVersion::Version_4:
1024 apply<InstanceNormPattern::PatternVersion::Version_4>();
1025 break;
1026 case InstanceNormPattern::PatternVersion::Version_5:
1027 apply<InstanceNormPattern::PatternVersion::Version_5>();
1028 break;
1029 case InstanceNormPattern::PatternVersion::Version_6:
1030 apply<InstanceNormPattern::PatternVersion::Version_6>();
1031 break;
1032
1033 default:
1034 break;
1035 }
1036}
1037
1038}
1039
1040namespace
1041{
1042
1043class PostFusion final
1044{
1045public:
1047
1048private:
1049 uint32_t input_channel(void);
1050
1052 bool match_const_gamma_channel(void);
1053 bool match_const_beta_channel(void);
1054
1055public:
1057
1058private:
1060};
1061
// Returns a channel count read from `input`, or 0 when it cannot be determined.
// For a rank-3 input the value of dim(1) is used; for any other rank >= 1 the
// last dimension is used.
// NOTE(review): the declaration of `input` and the condition guarding the
// second `return 0` are not visible in this listing.
1065uint32_t PostFusion::input_channel(void)
1066{
1068 if (input == nullptr)
1069 return 0;
1071 return 0;
1072
1073 auto input_rank =
input->rank();
// A rank below 1 has no dimension to read.
1074 if (input_rank < 1)
1075 return 0;
1076
1077 if (input_rank == 3)
1078 {
1079
// Rank 3: channel is taken from dim(1) (presumably a channel-first layout — TODO confirm).
1080 return input->dim(1).value();
1081 }
1082
// Otherwise the last dimension is treated as the channel.
1083 return input->dim(input_rank - 1).value();
1084}
1085
1090{
1092
1093 auto input_chn = input_const->dim(0).value();
1094 if (input_chn == 1 && input_chn != C)
1095 {
1096 float value = input_const->
at<loco::DataType::FLOAT32>(0);
1098
1099 new_input_const = loco::must_cast<luci::CircleConst *>(clone);
1100 new_input_const->rank(1);
1101 new_input_const->dim(0).set(C);
1102 new_input_const->
size<loco::DataType::FLOAT32>(
C);
1103 for (uint32_t c = 0; c <
C; ++c)
1104 new_input_const->
at<loco::DataType::FLOAT32>(c) = value;
1105 }
1106
1107 return new_input_const;
1108}
1109
// Replaces the gamma input of _inst_norm with the constant returned by
// match_const_channel() for the channel count from input_channel().
// Returns false (leaving gamma untouched) when const_as_gamma is null, the
// channel count is 0, or match_const_channel() returns nullptr; true otherwise.
// NOTE(review): the declaration of `const_as_gamma` is not visible in this listing.
1113bool PostFusion::match_const_gamma_channel(void)
1114{
1116 if (const_as_gamma == nullptr)
1117 return false;
1118
// 0 is input_channel()'s "unknown" result.
1119 auto C = input_channel();
1120 if (C == 0)
1121 return false;
1122
1123 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1124 if (new_const_as_gamma == nullptr)
1125 return false;
1126
1127 _inst_norm->gamma(new_const_as_gamma);
1128
1129 return true;
1130}
1131
// Replaces the beta input of _inst_norm with the constant returned by
// match_const_channel() for the channel count from input_channel().
// Returns false (leaving beta untouched) when const_as_beta is null, the
// channel count is 0, or match_const_channel() returns nullptr; true otherwise.
// NOTE(review): the declaration of `const_as_beta` is not visible in this listing.
1135bool PostFusion::match_const_beta_channel(void)
1136{
1138 if (const_as_beta == nullptr)
1139 return false;
1140
// 0 is input_channel()'s "unknown" result.
1141 auto C = input_channel();
1142 if (C == 0)
1143 return false;
1144
1145 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1146 if (new_const_as_beta == nullptr)
1147 return false;
1148
1149 _inst_norm->beta(new_const_as_beta);
1150
1151 return true;
1152}
1153
1154bool PostFusion::process(void)
1155{
1156 bool changed = false;
1157
1158 if (match_const_gamma_channel())
1159 changed = true;
1160 if (match_const_beta_channel())
1161 changed = true;
1162
1163 return changed;
1164}
1165
1166}
1167
1168namespace
1169{
1170
1172{
1175
1176 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1177}
1178
1180{
1183
1184 if (!
luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1185 return false;
1186
1188 if (sub == nullptr)
1189 return false;
1190
1192 if (const_as_beta == nullptr || const_as_beta->rank() != 3)
1193 return false;
1194
1195 return true;
1196}
1197
1199{
1200 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1201
1202 if (is_add_input_mul_const(add))
1203 pv = InstanceNormPattern::PatternVersion::Version_2;
1204 else if (is_add_input_mul_sub3d(add))
1205 pv = InstanceNormPattern::PatternVersion::Version_6;
1206
1207 InstanceNormPattern pattern(add, pv);
1208 if (pattern.matched())
1209 {
1210 FuseInstanceNorm fuse(pattern);
1211 fuse.apply();
1212 return true;
1213 }
1214
1215 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1216 {
1217
1218 pv = InstanceNormPattern::PatternVersion::Version_5;
1219 InstanceNormPattern pattern(add, pv);
1220 if (pattern.matched())
1221 {
1222 FuseInstanceNorm fuse(pattern);
1223 fuse.apply();
1224 return true;
1225 }
1226 }
1227 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1228 {
1229
1230 pv = InstanceNormPattern::PatternVersion::Version_3;
1231 InstanceNormPattern pattern(add, pv);
1232 if (pattern.matched())
1233 {
1234 FuseInstanceNorm fuse(pattern);
1235 fuse.apply();
1236 return true;
1237 }
1238 }
1239
1240 return false;
1241}
1242
1244{
1245 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1246
1247 InstanceNormPattern pattern(div, pv);
1248 if (pattern.matched())
1249 {
1250 FuseInstanceNorm fuse(pattern);
1251 fuse.apply();
1252 return true;
1253 }
1254
1255 return false;
1256}
1257
1259{
1260 PostFusion postfusion(inst_norm);
1261
1262 return postfusion.process();
1263}
1264
1265}
1266
1268{
1269
1271{
1272 bool changed = false;
1273
1274
1276 {
1278 if (not add)
1279 continue;
1280
1281 if (fuse_instance_norm(add))
1282 changed = true;
1283 }
1284
1285
1287 {
1289 if (not div)
1290 continue;
1291
1292 if (fuse_instance_norm(div))
1293 changed = true;
1294 }
1295
1296
1298 {
1300 if (not inst_norm)
1301 continue;
1302
1303 if (post_fusion(inst_norm))
1304 changed = true;
1305 }
1306
1307 return changed;
1308}
1309
1310}
Logical unit of computation.
void with(Node *into) const
Class to build tensor data.
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
loco::Node * input(void) const
SQUARED_DIFFERENCE in Circle.
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Subst< SubstQualifier::Default > replace(Node *node)
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
bool run(loco::Graph *g) final
Run the pass.