34 auto rank = node->rank();
36 for (axis = 0; axis < rank - 1; ++axis)
38 if (node->dim(axis).value() != 1)
41 return node->dim(axis).value() == depth;
49 auto input = loco::must_cast<luci::CircleNode *>(mean->
input());
52 if (input->rank() != 4)
64 if (red_indices->rank() != 1)
66 std::set<int32_t> red_indices_set;
69 assert(red_indices->dtype() == loco::DataType::S32);
70 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
71 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
73 if (red_indices_set.size() != 2)
75 if (red_indices_set.find(1) == red_indices_set.end())
77 if (red_indices_set.find(2) == red_indices_set.end())
93 auto input = loco::must_cast<luci::CircleNode *>(mean->
input());
96 if (input->rank() != 3)
106 if (red_indices->rank() != 1)
108 std::set<int32_t> red_indices_set;
111 assert(red_indices->dtype() == loco::DataType::S32);
112 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
113 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
115 if (red_indices_set.size() != 1)
117 if (red_indices_set.find(2) == red_indices_set.end())
131 if (node->rank() != 1)
134 if (node->dim(0).value() != channel_size)
137 if (node->dtype() != loco::DataType::FLOAT32)
140 if (node->
size<loco::DataType::FLOAT32>() != channel_size)
374class InstanceNormPattern final
391 add_as_terminal = candidate;
403 bool condition_common_1_5(uint32_t ifm_channel_depth);
404 bool condition_common_3_4();
407 template <enum PatternVersion>
bool match();
411 bool matched()
const {
return _matched; }
413 PatternVersion version()
const {
return _pv; }
442 bool _matched =
false;
446#define CHECK_OR_FALSE(condition) \
447 if (not(condition)) \
450bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
456 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
458 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
460 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
480bool InstanceNormPattern::condition_common_3_4()
496 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
498 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
517 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
524template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
529 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
533 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
547 .with_commutative_args_of(mul_as_scaled_mean));
555template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
569 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
584 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
590 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
591 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
593 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
603 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
615template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
630 const_one->dtype(loco::DataType::FLOAT32);
632 const_one->
size<loco::DataType::FLOAT32>(1);
633 const_one->at<loco::DataType::FLOAT32>(0) = value;
637template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
644 assert(const_as_gamma ==
nullptr);
645 assert(const_as_beta ==
nullptr);
646 assert(mul_gamma ==
nullptr);
647 assert(add_as_terminal ==
nullptr);
651 const_as_gamma = make_const_one(graph, 1.0f);
652 const_as_beta = make_const_one(graph, 0.0f);
653 const_as_gamma->name(
div->name() +
"/gamma");
654 const_as_beta->name(
div->name() +
"/beta");
660template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
665 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
669 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
679 .with_commutative_args_of(mul_as_scaled_mean));
685 auto graph = add_as_terminal->graph();
686 const_as_gamma = make_const_one(graph, 1.0f);
687 const_as_gamma->name(add_as_terminal->name() +
"/gamma");
693template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
698 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
707 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
709 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
711 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
725 uint32_t input_channel = ifm_circle->dim(1).value();
726 uint32_t input_last_dim = ifm_circle->dim(2).value();
731 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
732 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
740 .with_commutative_args_of(mul_as_scaled_mean));
746 auto graph = add_as_terminal->graph();
747 const_as_gamma = make_const_one(graph, 1.0f);
748 const_as_gamma->name(add_as_terminal->name() +
"/gamma");
754bool InstanceNormPattern::matched()
763 case PatternVersion::Version_1:
764 return match<PatternVersion::Version_1>();
765 case PatternVersion::Version_2:
766 return match<PatternVersion::Version_2>();
767 case PatternVersion::Version_3:
768 return match<PatternVersion::Version_3>();
769 case PatternVersion::Version_4:
770 return match<PatternVersion::Version_4>();
771 case PatternVersion::Version_5:
772 return match<PatternVersion::Version_5>();
773 case PatternVersion::Version_6:
774 return match<PatternVersion::Version_6>();
780 throw std::runtime_error(
"Invalid InstanceNorm PatternVersion.");
801class FuseInstanceNorm final
804 FuseInstanceNorm(
const InstanceNormPattern &p) : _p(p) {}
810 template <InstanceNormPattern::PatternVersion>
void apply(
void);
813 void reshape_gamma_beta(
void);
817 const InstanceNormPattern &_p;
820void FuseInstanceNorm::reshape_gamma_beta()
824 _p.const_as_gamma->rank(1);
825 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
826 _p.const_as_beta->rank(1);
827 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
838 instance_norm->
input(_p.ifm);
839 instance_norm->gamma(_p.const_as_gamma);
840 instance_norm->beta(_p.const_as_beta);
841 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
842 instance_norm->epsilon(epsilon);
843 if (_p.add_as_terminal !=
nullptr)
845 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
847 instance_norm->name(
"FusedInstanceNorm/" + _p.add_as_terminal->name());
852 assert(_p.div !=
nullptr);
853 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
854 instance_norm->name(
"FusedInstanceNorm/" + _p.div->name());
857 return instance_norm;
860template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
862 auto graph = _p.add_as_terminal->graph();
864 reshape_gamma_beta();
866 auto instance_norm = create_inst_norm(graph);
869 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
870 luci::get_origin(_p.mean_of_ifm),
871 luci::get_origin(_p.sqdiff),
872 luci::get_origin(_p.mean_as_variance),
873 luci::get_origin(_p.add_as_variance),
874 luci::get_origin(_p.rsqrt),
875 luci::get_origin(_p.mul_gamma),
876 luci::get_origin(_p.mul_as_scaled_ifm),
877 luci::get_origin(_p.mul_as_scaled_mean),
878 luci::get_origin(_p.sub),
879 luci::get_origin(_p.add_as_terminal)};
886template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
888 auto graph = _p.add_as_terminal->graph();
890 auto instance_norm = create_inst_norm(graph);
893 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
894 luci::get_origin(_p.mean_of_ifm),
895 luci::get_origin(_p.sqdiff),
896 luci::get_origin(_p.mean_as_variance),
897 luci::get_origin(_p.add_as_variance),
898 luci::get_origin(_p.pow),
899 luci::get_origin(_p.sub),
900 luci::get_origin(_p.div),
901 luci::get_origin(_p.mul_gamma),
902 luci::get_origin(_p.add_as_terminal)};
909template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
911 auto graph = _p.add_as_terminal->graph();
913 reshape_gamma_beta();
915 auto instance_norm = create_inst_norm(graph);
918 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
919 luci::get_origin(_p.mean_of_ifm),
920 luci::get_origin(_p.sub),
921 luci::get_origin(_p.mean_of_ifm_2),
922 luci::get_origin(_p.sub_2),
923 luci::get_origin(_p.square),
924 luci::get_origin(_p.mean_as_variance),
925 luci::get_origin(_p.sqrt),
926 luci::get_origin(_p.add_as_variance),
927 luci::get_origin(_p.div),
928 luci::get_origin(_p.mul_gamma),
929 luci::get_origin(_p.add_as_terminal)};
936template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
938 auto graph = _p.div->graph();
940 auto instance_norm = create_inst_norm(graph);
943 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
944 luci::get_origin(_p.mean_of_ifm),
945 luci::get_origin(_p.sub),
946 luci::get_origin(_p.mean_of_ifm_2),
947 luci::get_origin(_p.sub_2),
948 luci::get_origin(_p.square),
949 luci::get_origin(_p.mean_as_variance),
950 luci::get_origin(_p.sqrt),
951 luci::get_origin(_p.add_as_variance),
952 luci::get_origin(_p.div)};
959template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
961 auto graph = _p.add_as_terminal->graph();
963 reshape_gamma_beta();
965 auto instance_norm = create_inst_norm(graph);
968 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
969 luci::get_origin(_p.mean_of_ifm),
970 luci::get_origin(_p.sqdiff),
971 luci::get_origin(_p.mean_as_variance),
972 luci::get_origin(_p.add_as_variance),
973 luci::get_origin(_p.rsqrt),
974 luci::get_origin(_p.mul_as_scaled_ifm),
975 luci::get_origin(_p.mul_as_scaled_mean),
976 luci::get_origin(_p.sub),
977 luci::get_origin(_p.add_as_terminal)};
984template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
986 auto graph = _p.add_as_terminal->graph();
988 reshape_gamma_beta();
990 auto instance_norm = create_inst_norm(graph);
993 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
994 luci::get_origin(_p.mean_of_ifm),
995 luci::get_origin(_p.sqdiff),
996 luci::get_origin(_p.mean_as_variance),
997 luci::get_origin(_p.add_as_variance),
998 luci::get_origin(_p.rsqrt),
999 luci::get_origin(_p.mul_as_scaled_ifm),
1000 luci::get_origin(_p.mul_as_scaled_mean),
1001 luci::get_origin(_p.sub),
1002 luci::get_origin(_p.add_as_terminal)};
1009void FuseInstanceNorm::apply()
1011 assert(_p.matched());
1013 switch (_p.version())
1015 case InstanceNormPattern::PatternVersion::Version_1:
1016 apply<InstanceNormPattern::PatternVersion::Version_1>();
1018 case InstanceNormPattern::PatternVersion::Version_2:
1019 apply<InstanceNormPattern::PatternVersion::Version_2>();
1021 case InstanceNormPattern::PatternVersion::Version_3:
1022 apply<InstanceNormPattern::PatternVersion::Version_3>();
1024 case InstanceNormPattern::PatternVersion::Version_4:
1025 apply<InstanceNormPattern::PatternVersion::Version_4>();
1027 case InstanceNormPattern::PatternVersion::Version_5:
1028 apply<InstanceNormPattern::PatternVersion::Version_5>();
1030 case InstanceNormPattern::PatternVersion::Version_6:
1031 apply<InstanceNormPattern::PatternVersion::Version_6>();
1044class PostFusion final
1050 uint32_t input_channel(
void);
1053 bool match_const_gamma_channel(
void);
1054 bool match_const_beta_channel(
void);
1066uint32_t PostFusion::input_channel(
void)
1069 if (input ==
nullptr)
1074 auto input_rank =
input->rank();
1078 if (input_rank == 3)
1081 return input->dim(1).value();
1084 return input->dim(input_rank - 1).value();
1094 auto input_chn = input_const->dim(0).value();
1095 if (input_chn == 1 && input_chn != C)
1097 float value = input_const->
at<loco::DataType::FLOAT32>(0);
1100 new_input_const = loco::must_cast<luci::CircleConst *>(clone);
1101 new_input_const->rank(1);
1102 new_input_const->dim(0).set(C);
1103 new_input_const->
size<loco::DataType::FLOAT32>(
C);
1104 for (uint32_t c = 0; c <
C; ++c)
1105 new_input_const->
at<loco::DataType::FLOAT32>(c) = value;
1108 return new_input_const;
1114bool PostFusion::match_const_gamma_channel(
void)
1117 if (const_as_gamma ==
nullptr)
1120 auto C = input_channel();
1124 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1125 if (new_const_as_gamma ==
nullptr)
1128 _inst_norm->gamma(new_const_as_gamma);
1136bool PostFusion::match_const_beta_channel(
void)
1139 if (const_as_beta ==
nullptr)
1142 auto C = input_channel();
1146 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1147 if (new_const_as_beta ==
nullptr)
1150 _inst_norm->beta(new_const_as_beta);
1155bool PostFusion::process(
void)
1157 bool changed =
false;
1159 if (match_const_gamma_channel())
1161 if (match_const_beta_channel())
1177 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1185 if (!
luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1193 if (const_as_beta ==
nullptr || const_as_beta->rank() != 3)
1201 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1203 if (is_add_input_mul_const(add))
1204 pv = InstanceNormPattern::PatternVersion::Version_2;
1205 else if (is_add_input_mul_sub3d(add))
1206 pv = InstanceNormPattern::PatternVersion::Version_6;
1208 InstanceNormPattern pattern(add, pv);
1209 if (pattern.matched())
1211 FuseInstanceNorm fuse(pattern);
1216 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1219 pv = InstanceNormPattern::PatternVersion::Version_5;
1220 InstanceNormPattern pattern(add, pv);
1221 if (pattern.matched())
1223 FuseInstanceNorm fuse(pattern);
1228 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1231 pv = InstanceNormPattern::PatternVersion::Version_3;
1232 InstanceNormPattern pattern(add, pv);
1233 if (pattern.matched())
1235 FuseInstanceNorm fuse(pattern);
1246 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1248 InstanceNormPattern pattern(div, pv);
1249 if (pattern.matched())
1251 FuseInstanceNorm fuse(pattern);
1261 PostFusion postfusion(inst_norm);
1263 return postfusion.process();
1273 bool changed =
false;
1282 if (fuse_instance_norm(add))
1293 if (fuse_instance_norm(div))
1304 if (post_fusion(inst_norm))
Logical unit of computation.
void with(Node *into) const
Class to build tensor data.
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
loco::Node * input(void) const
bool keep_dims(void) const
loco::Node * input(void) const
loco::Node * reduction_indices(void) const
SQUARED_DIFFERENCE in Circle.
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Subst< SubstQualifier::Default > replace(Node *node)
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
bool run(loco::Graph *g) final
Run the pass.