32#define CHECK_OR_FALSE(condition) \
40 const auto rank = node->rank();
41 std::optional<uint32_t> depth_axis;
42 for (uint32_t axis = 0; axis < rank; ++axis)
44 if (node->dim(axis).value() != 1)
47 if (depth_axis.has_value())
54 if (!depth_axis.has_value())
58 return node->dim(depth_axis.value()).value() == depth;
71 CHECK_OR_FALSE((begin_reshape_ifm->rank() + 1) == begin_reshape->rank());
74 for (uint32_t axis = 0; axis < begin_reshape_ifm->rank(); ++axis)
77 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).known() && begin_reshape->dim(axis).known());
78 CHECK_OR_FALSE(begin_reshape_ifm->dim(axis).value() == begin_reshape->dim(axis).value());
81 CHECK_OR_FALSE(begin_reshape->dim(begin_reshape->rank() - 1) == 1);
86 CHECK_OR_FALSE(terminal_reshape_ifm->rank() == terminal_reshape->rank() + 1);
89 CHECK_OR_FALSE(terminal_reshape_ifm->dim(begin_reshape->rank() - 1) == 1);
92 for (uint32_t axis = 0; axis < terminal_reshape->rank(); ++axis)
95 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).known() && terminal_reshape->dim(axis).known());
96 CHECK_OR_FALSE(terminal_reshape_ifm->dim(axis).value() == terminal_reshape->dim(axis).value());
110 if (input->rank() != 4)
122 if (red_indices->rank() != 1)
124 std::set<int32_t> red_indices_set;
127 assert(red_indices->dtype() == loco::DataType::S32);
128 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
129 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
131 if (red_indices_set.size() != 2)
133 if (red_indices_set.find(1) == red_indices_set.end())
135 if (red_indices_set.find(2) == red_indices_set.end())
154 if (input->rank() != 3)
164 if (red_indices->rank() != 1)
166 std::set<int32_t> red_indices_set;
169 assert(red_indices->dtype() == loco::DataType::S32);
170 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
171 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
173 if (red_indices_set.size() != 1)
175 if (red_indices_set.find(2) == red_indices_set.end())
189 if (node->rank() != 1)
192 if (node->dim(0).value() != channel_size)
195 if (node->dtype() != loco::DataType::FLOAT32)
198 if (node->
size<loco::DataType::FLOAT32>() != channel_size)
483class InstanceNormPattern final
501 add_as_terminal = candidate;
515 reshape_as_terminal = candidate;
520 bool condition_common_1_5(uint32_t ifm_channel_depth);
521 bool condition_common_3_4();
524 template <enum PatternVersion>
bool match();
528 bool matched()
const {
return _matched; }
530 PatternVersion
version()
const {
return _pv; }
563 bool _matched =
false;
567bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
573 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
575 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
577 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
597bool InstanceNormPattern::condition_common_3_4()
613 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
615 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
634 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
641template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
650 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
664 .with_commutative_args_of(mul_as_scaled_mean));
672template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
686 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
701 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
707 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
708 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
710 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
720 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
732template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
747 const_one->dtype(loco::DataType::FLOAT32);
749 const_one->
size<loco::DataType::FLOAT32>(1);
750 const_one->at<loco::DataType::FLOAT32>(0) = value;
754template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
761 assert(const_as_gamma ==
nullptr);
762 assert(const_as_beta ==
nullptr);
763 assert(mul_gamma ==
nullptr);
764 assert(add_as_terminal ==
nullptr);
768 const_as_gamma = make_const_one(graph, 1.0f);
769 const_as_beta = make_const_one(graph, 0.0f);
770 const_as_gamma->name(
div->name() +
"/gamma");
771 const_as_beta->name(
div->name() +
"/beta");
777template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
786 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
796 .with_commutative_args_of(mul_as_scaled_mean));
802 auto graph = add_as_terminal->graph();
803 const_as_gamma = make_const_one(graph, 1.0f);
804 const_as_gamma->name(add_as_terminal->name() +
"/gamma");
810template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
824 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
826 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
828 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
842 uint32_t input_channel = ifm_circle->dim(1).value();
843 uint32_t input_last_dim = ifm_circle->dim(2).value();
848 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
849 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
857 .with_commutative_args_of(mul_as_scaled_mean));
863 auto graph = add_as_terminal->graph();
864 const_as_gamma = make_const_one(graph, 1.0f);
865 const_as_gamma->name(add_as_terminal->name() +
"/gamma");
871template <>
bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_7>()
873 add_as_terminal =
dynamic_cast<luci::CircleAdd *
>(reshape_as_terminal->tensor());
877 luci::fill(&mul_as_scaled_ifm, &add_neg_mul).with_commutative_args_of(add_as_terminal));
879 luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
881 mul_as_scaled_mean =
dynamic_cast<luci::CircleMul *
>(add_neg_mul->x());
891 luci::fill(&mul_gamma_should_be, &neg_should_be).with_commutative_args_of(mul_as_scaled_mean));
903 ifm = reshape_of_ifm->tensor();
910 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
923 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
926 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
928 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
948bool InstanceNormPattern::matched()
957 case PatternVersion::Version_1:
958 return match<PatternVersion::Version_1>();
959 case PatternVersion::Version_2:
960 return match<PatternVersion::Version_2>();
961 case PatternVersion::Version_3:
962 return match<PatternVersion::Version_3>();
963 case PatternVersion::Version_4:
964 return match<PatternVersion::Version_4>();
965 case PatternVersion::Version_5:
966 return match<PatternVersion::Version_5>();
967 case PatternVersion::Version_6:
968 return match<PatternVersion::Version_6>();
969 case PatternVersion::Version_7:
970 return match<PatternVersion::Version_7>();
976 throw std::runtime_error(
"Invalid InstanceNorm PatternVersion.");
997class FuseInstanceNorm final
1000 FuseInstanceNorm(
const InstanceNormPattern &
p) : _p(
p) {}
1006 template <InstanceNormPattern::PatternVersion>
void apply(
void);
1009 void reshape_gamma_beta(
void);
1013 const InstanceNormPattern &_p;
1016void FuseInstanceNorm::reshape_gamma_beta()
1020 _p.const_as_gamma->rank(1);
1021 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
1022 _p.const_as_beta->rank(1);
1023 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
1034 instance_norm->
input(_p.ifm);
1035 instance_norm->gamma(_p.const_as_gamma);
1036 instance_norm->beta(_p.const_as_beta);
1037 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
1038 instance_norm->epsilon(epsilon);
1039 if (_p.add_as_terminal !=
nullptr)
1041 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
1043 instance_norm->name(
"FusedInstanceNorm/" + _p.add_as_terminal->name());
1048 assert(_p.div !=
nullptr);
1049 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
1050 instance_norm->name(
"FusedInstanceNorm/" + _p.div->name());
1053 return instance_norm;
1056template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
1058 auto graph = _p.add_as_terminal->graph();
1060 reshape_gamma_beta();
1062 auto instance_norm = create_inst_norm(graph);
1065 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1066 luci::get_origin(_p.mean_of_ifm),
1067 luci::get_origin(_p.sqdiff),
1068 luci::get_origin(_p.mean_as_variance),
1069 luci::get_origin(_p.add_as_variance),
1070 luci::get_origin(_p.rsqrt),
1071 luci::get_origin(_p.mul_gamma),
1072 luci::get_origin(_p.mul_as_scaled_ifm),
1073 luci::get_origin(_p.mul_as_scaled_mean),
1074 luci::get_origin(_p.sub),
1075 luci::get_origin(_p.add_as_terminal)};
1082template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
1084 auto graph = _p.add_as_terminal->graph();
1086 auto instance_norm = create_inst_norm(graph);
1089 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1090 luci::get_origin(_p.mean_of_ifm),
1091 luci::get_origin(_p.sqdiff),
1092 luci::get_origin(_p.mean_as_variance),
1093 luci::get_origin(_p.add_as_variance),
1094 luci::get_origin(_p.pow),
1095 luci::get_origin(_p.sub),
1096 luci::get_origin(_p.div),
1097 luci::get_origin(_p.mul_gamma),
1098 luci::get_origin(_p.add_as_terminal)};
1105template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
1107 auto graph = _p.add_as_terminal->graph();
1109 reshape_gamma_beta();
1111 auto instance_norm = create_inst_norm(graph);
1114 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1115 luci::get_origin(_p.mean_of_ifm),
1116 luci::get_origin(_p.sub),
1117 luci::get_origin(_p.mean_of_ifm_2),
1118 luci::get_origin(_p.sub_2),
1119 luci::get_origin(_p.square),
1120 luci::get_origin(_p.mean_as_variance),
1121 luci::get_origin(_p.sqrt),
1122 luci::get_origin(_p.add_as_variance),
1123 luci::get_origin(_p.div),
1124 luci::get_origin(_p.mul_gamma),
1125 luci::get_origin(_p.add_as_terminal)};
1132template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
1134 auto graph = _p.div->graph();
1136 auto instance_norm = create_inst_norm(graph);
1139 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1140 luci::get_origin(_p.mean_of_ifm),
1141 luci::get_origin(_p.sub),
1142 luci::get_origin(_p.mean_of_ifm_2),
1143 luci::get_origin(_p.sub_2),
1144 luci::get_origin(_p.square),
1145 luci::get_origin(_p.mean_as_variance),
1146 luci::get_origin(_p.sqrt),
1147 luci::get_origin(_p.add_as_variance),
1148 luci::get_origin(_p.div)};
1155template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
1157 auto graph = _p.add_as_terminal->graph();
1159 reshape_gamma_beta();
1161 auto instance_norm = create_inst_norm(graph);
1164 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1165 luci::get_origin(_p.mean_of_ifm),
1166 luci::get_origin(_p.sqdiff),
1167 luci::get_origin(_p.mean_as_variance),
1168 luci::get_origin(_p.add_as_variance),
1169 luci::get_origin(_p.rsqrt),
1170 luci::get_origin(_p.mul_as_scaled_ifm),
1171 luci::get_origin(_p.mul_as_scaled_mean),
1172 luci::get_origin(_p.sub),
1173 luci::get_origin(_p.add_as_terminal)};
1180template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
1182 auto graph = _p.add_as_terminal->graph();
1184 reshape_gamma_beta();
1186 auto instance_norm = create_inst_norm(graph);
1189 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1190 luci::get_origin(_p.mean_of_ifm),
1191 luci::get_origin(_p.sqdiff),
1192 luci::get_origin(_p.mean_as_variance),
1193 luci::get_origin(_p.add_as_variance),
1194 luci::get_origin(_p.rsqrt),
1195 luci::get_origin(_p.mul_as_scaled_ifm),
1196 luci::get_origin(_p.mul_as_scaled_mean),
1197 luci::get_origin(_p.sub),
1198 luci::get_origin(_p.add_as_terminal)};
1205template <>
void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_7>()
1207 auto graph = _p.reshape_as_terminal->graph();
1209 reshape_gamma_beta();
1211 auto instance_norm = create_inst_norm(graph);
1214 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
1215 luci::get_origin(_p.reshape_of_ifm),
1216 luci::get_origin(_p.mean_of_ifm),
1217 luci::get_origin(_p.sub_2),
1218 luci::get_origin(_p.square),
1219 luci::get_origin(_p.mean_as_variance),
1220 luci::get_origin(_p.add_as_variance),
1221 luci::get_origin(_p.rsqrt),
1222 luci::get_origin(_p.mul_gamma),
1223 luci::get_origin(_p.neg_mean),
1224 luci::get_origin(_p.mul_as_scaled_ifm),
1225 luci::get_origin(_p.mul_as_scaled_mean),
1226 luci::get_origin(_p.add_neg_mul),
1227 luci::get_origin(_p.add_as_terminal),
1228 luci::get_origin(_p.reshape_as_terminal)};
1232 replace(_p.reshape_as_terminal).
with(instance_norm);
1235void FuseInstanceNorm::apply()
1237 assert(_p.matched());
1239 switch (_p.version())
1241 case InstanceNormPattern::PatternVersion::Version_1:
1242 apply<InstanceNormPattern::PatternVersion::Version_1>();
1244 case InstanceNormPattern::PatternVersion::Version_2:
1245 apply<InstanceNormPattern::PatternVersion::Version_2>();
1247 case InstanceNormPattern::PatternVersion::Version_3:
1248 apply<InstanceNormPattern::PatternVersion::Version_3>();
1250 case InstanceNormPattern::PatternVersion::Version_4:
1251 apply<InstanceNormPattern::PatternVersion::Version_4>();
1253 case InstanceNormPattern::PatternVersion::Version_5:
1254 apply<InstanceNormPattern::PatternVersion::Version_5>();
1256 case InstanceNormPattern::PatternVersion::Version_6:
1257 apply<InstanceNormPattern::PatternVersion::Version_6>();
1259 case InstanceNormPattern::PatternVersion::Version_7:
1260 apply<InstanceNormPattern::PatternVersion::Version_7>();
1273class PostFusion final
1279 uint32_t input_channel(
void);
1282 bool match_const_gamma_channel(
void);
1283 bool match_const_beta_channel(
void);
1295uint32_t PostFusion::input_channel(
void)
1298 if (input ==
nullptr)
1303 auto input_rank =
input->rank();
1307 if (input_rank == 3)
1310 return input->dim(1).value();
1313 return input->dim(input_rank - 1).value();
1323 auto input_chn = input_const->dim(0).value();
1324 if (input_chn == 1 && input_chn != C)
1326 float value = input_const->
at<loco::DataType::FLOAT32>(0);
1330 new_input_const->rank(1);
1331 new_input_const->dim(0).set(C);
1332 new_input_const->
size<loco::DataType::FLOAT32>(
C);
1333 for (uint32_t c = 0; c <
C; ++c)
1334 new_input_const->
at<loco::DataType::FLOAT32>(c) = value;
1337 return new_input_const;
1343bool PostFusion::match_const_gamma_channel(
void)
1346 if (const_as_gamma ==
nullptr)
1349 auto C = input_channel();
1353 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1354 if (new_const_as_gamma ==
nullptr)
1357 _inst_norm->gamma(new_const_as_gamma);
1365bool PostFusion::match_const_beta_channel(
void)
1368 if (const_as_beta ==
nullptr)
1371 auto C = input_channel();
1375 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1376 if (new_const_as_beta ==
nullptr)
1379 _inst_norm->beta(new_const_as_beta);
1384bool PostFusion::process(
void)
1386 bool changed =
false;
1388 if (match_const_gamma_channel())
1390 if (match_const_beta_channel())
1406 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1414 if (!
luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1422 if (const_as_beta ==
nullptr || const_as_beta->rank() != 3)
1430 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1432 if (is_add_input_mul_const(add))
1433 pv = InstanceNormPattern::PatternVersion::Version_2;
1434 else if (is_add_input_mul_sub3d(add))
1435 pv = InstanceNormPattern::PatternVersion::Version_6;
1437 InstanceNormPattern pattern(add, pv);
1438 if (pattern.matched())
1440 FuseInstanceNorm fuse(pattern);
1445 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1448 pv = InstanceNormPattern::PatternVersion::Version_5;
1449 InstanceNormPattern pattern(add, pv);
1450 if (pattern.matched())
1452 FuseInstanceNorm fuse(pattern);
1457 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1460 pv = InstanceNormPattern::PatternVersion::Version_3;
1461 InstanceNormPattern pattern(add, pv);
1462 if (pattern.matched())
1464 FuseInstanceNorm fuse(pattern);
1475 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1477 InstanceNormPattern pattern(div, pv);
1478 if (pattern.matched())
1480 FuseInstanceNorm fuse(pattern);
1490 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_7;
1492 InstanceNormPattern pattern(reshape, pv);
1493 if (pattern.matched())
1495 FuseInstanceNorm fuse(pattern);
1505 PostFusion postfusion(inst_norm);
1507 return postfusion.process();
1517 bool changed =
false;
Logical unit of computation.
void with(Node *into) const
Class to build tensor data.
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
loco::Node * input(void) const
bool keep_dims(void) const
loco::Node * input(void) const
loco::Node * reduction_indices(void) const
loco::Node * tensor(void) const
SQUARED_DIFFERENCE in Circle.
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_unsqueeze_squeeze_pair(luci::CircleReshape *begin_reshape, luci::CircleReshape *terminal_reshape)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_unsqueezed_1D(luci::CircleConst *node, uint32_t depth)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
T must_cast(FeatureEncoder *node)
A helper dynamic_cast that throws when failed.
std::vector< Node * > output_nodes(Graph *)
Subst< SubstQualifier::Default > replace(Node *node)
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
T must_cast(loco::Node *node)
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
bool run(loco::Graph *g) final
Run the pass.