ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
FuseInstanceNormPass.cpp File Reference

Go to the source code of this file.

Namespaces

namespace  luci
 

Macros

#define CHECK_OR_FALSE(condition)
 

Functions

bool is_1D_with_dummy_dim (luci::CircleConst *node, uint32_t depth)
 
bool is_instance_mean_v1 (luci::CircleMean *mean)
 
bool is_instance_mean_v2 (luci::CircleMean *mean)
 
bool is_1D_float32_const (const luci::CircleConst *node, uint32_t channel_size)
 

Macro Definition Documentation

◆ CHECK_OR_FALSE

#define CHECK_OR_FALSE(condition)
Value:
if (not(condition)) \
return false;

Definition at line 446 of file FuseInstanceNormPass.cpp.

450{
// Shared tail of the Version_1 and Version_5 matchers. Both call condition_common_1_5 with
// the channel depth taken from dim(3) of their rank-4 ifm. It walks
// rsqrt -> add_as_variance -> mean_as_variance -> sqdiff back to ifm / mean_of_ifm, then
// picks up beta from sub->x(). Returns false on the first structural mismatch.
// NOTE(review): this rendering omits the signature line (449) and source lines 454 and 469.
// Line 454 appears to be the opening `CHECK_OR_FALSE(` of the statement continued on 455.
451 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
452 CHECK_OR_FALSE(add_as_variance);
453
// add_as_variance must be mean_as_variance + const_as_epsilon, in either operand order
455 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
456
// epsilon must be a single FLOAT32 value (create_inst_norm later reads element 0 of it)
457 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
458 // TODO Support regarding broadcast
459 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
460
461 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
462
// The variance mean must consume a SquaredDifference of ifm and mean_of_ifm
463 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
464 CHECK_OR_FALSE(sqdiff);
465
466 loco::Node *ifm_should_be = nullptr;
467 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
468 CHECK_OR_FALSE(ifm == ifm_should_be);
// NOTE(review): source line 469 is missing here. From this rendering alone it cannot be told
// whether mean_of_ifm's reduction axes are checked (compare line 461 for mean_as_variance).
470 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
471
// beta must be a constant shaped '1 x .. x 1 x depth' (see is_1D_with_dummy_dim)
472 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
473 CHECK_OR_FALSE(const_as_beta);
474 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
475
476 return true;
477}
478
479bool InstanceNormPattern::condition_common_3_4()
480{
481 // check left sub
482 ifm = sub->x();
483 CHECK_OR_FALSE(ifm);
484
485 luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
486 CHECK_OR_FALSE(ifm_node->rank() == 4);
487 CHECK_OR_FALSE(ifm_node->dim(3).known());
488
489 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
490 CHECK_OR_FALSE(mean_of_ifm);
491 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
492
493 // continue search from add_as_variance
494 CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
495 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
496 // TODO Support regarding broadcast
497 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
498
499 mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
500 CHECK_OR_FALSE(mean_as_variance);
501
502 square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
503 CHECK_OR_FALSE(square);
504
505 sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
506 CHECK_OR_FALSE(sub_2);
507 CHECK_OR_FALSE(ifm == sub_2->x());
508
509 mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
510 CHECK_OR_FALSE(mean_of_ifm_2);
511 CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
512
513 loco::Node *ifm_should_be = nullptr;
514 luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
516 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
517 CHECK_OR_FALSE(ifm == ifm_should_be);
518 CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
519
520 return true;
521}
522
523template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
524{
525 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
526 CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
527
528 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
529 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
530 CHECK_OR_FALSE(ifm_circle->rank() == 4);
531 CHECK_OR_FALSE(ifm_circle->dim(3).known());
532 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
533
534 CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
535
536 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
537
538 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
539
540 luci::CircleMul *mul_gamma_should_be = nullptr;
541 luci::CircleMean *mean_of_ifm_should_be = nullptr;
542
543 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
544 CHECK_OR_FALSE(mul_as_scaled_mean);
545 CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
546 .with_commutative_args_of(mul_as_scaled_mean));
547 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
548 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
549
550 _matched = true;
551 return true;
552}
553
554template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
555{
556 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
557 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
558
559 sub = dynamic_cast<luci::CircleSub *>(div->x());
560 CHECK_OR_FALSE(sub);
561
562 ifm = sub->x();
563 CHECK_OR_FALSE(ifm);
564
565 luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
566 CHECK_OR_FALSE(ifm_node->rank() == 4);
567 CHECK_OR_FALSE(ifm_node->dim(3).known());
568 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
569
570 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
571 CHECK_OR_FALSE(mean_of_ifm);
572
573 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
574
575 pow = dynamic_cast<luci::CirclePow *>(div->y());
576 CHECK_OR_FALSE(pow);
577
578 add_as_variance = dynamic_cast<luci::CircleAdd *>(pow->x());
579 CHECK_OR_FALSE(add_as_variance);
580
581 luci::CircleConst *zero_point_five = dynamic_cast<luci::CircleConst *>(pow->y());
582 CHECK_OR_FALSE(zero_point_five);
583 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
584 // TODO Support regarding broadcast
585 CHECK_OR_FALSE(zero_point_five->size<loco::DataType::FLOAT32>() == 1);
586 CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
587
589 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
590 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
591 // TODO Support regarding broadcast
592 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
593
594 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
595
596 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
597 CHECK_OR_FALSE(sqdiff);
598
599 loco::Node *ifm_should_be = nullptr;
600 luci::CircleMean *mean_of_ifm_should_be = nullptr;
602 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
603 CHECK_OR_FALSE(ifm == ifm_should_be);
604 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
605
606 // Check for channel size
607 CHECK_OR_FALSE(is_1D_float32_const(const_as_gamma, ifm_channel_depth));
608 CHECK_OR_FALSE(is_1D_float32_const(const_as_beta, ifm_channel_depth));
609
610 _matched = true;
611 return true;
612}
613
614template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
615{
616 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
617 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
618 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
619
620 CHECK_OR_FALSE(condition_common_3_4());
621
622 _matched = true;
623 return true;
624}
625
626luci::CircleConst *make_const_one(loco::Graph *graph, float value)
627{
628 auto const_one = graph->nodes()->create<luci::CircleConst>();
629 const_one->dtype(loco::DataType::FLOAT32);
630 const_one->rank(1);
631 const_one->size<loco::DataType::FLOAT32>(1);
632 const_one->at<loco::DataType::FLOAT32>(0) = value;
633 return const_one;
634}
635
636template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
637{
638 CHECK_OR_FALSE(div);
639 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
640
641 CHECK_OR_FALSE(condition_common_3_4());
642
643 assert(const_as_gamma == nullptr);
644 assert(const_as_beta == nullptr);
645 assert(mul_gamma == nullptr);
646 assert(add_as_terminal == nullptr);
647
648 // create 1.0 gamma and 0.0 beta
649 auto graph = div->graph();
650 const_as_gamma = make_const_one(graph, 1.0f);
651 const_as_beta = make_const_one(graph, 0.0f);
652 const_as_gamma->name(div->name() + "/gamma");
653 const_as_beta->name(div->name() + "/beta");
654
655 _matched = true;
656 return true;
657}
658
659template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
660{
661 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
662 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
663
664 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
665 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
666 CHECK_OR_FALSE(ifm_circle->rank() == 4);
667 CHECK_OR_FALSE(ifm_circle->dim(3).known());
668 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
669
670 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
671
672 luci::CircleRsqrt *rsqrt_should_be = nullptr;
673 luci::CircleMean *mean_of_ifm_should_be = nullptr;
674
675 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
676 CHECK_OR_FALSE(mul_as_scaled_mean);
677 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
678 .with_commutative_args_of(mul_as_scaled_mean));
679 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
680 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
681
682 // mul_gamma is absent
683 // const_as_gamma assume to be 1.0
684 auto graph = add_as_terminal->graph();
685 const_as_gamma = make_const_one(graph, 1.0f);
686 const_as_gamma->name(add_as_terminal->name() + "/gamma");
687
688 _matched = true;
689 return true;
690}
691
692template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
693{
// Version_6: a rank-3 variant of the Version_5 structure (no gamma MUL; gamma is synthesized
// as 1.0 at the end). The means are checked with is_instance_mean_v2 (reduce axis 2), and
// dim(1) of the rank-3 ifm is treated as the channel.
// NOTE(review): this rendering omits source lines 705, 720 and 729. Lines 705 and 729 appear
// to be the opening `CHECK_OR_FALSE(` of the statements continued on 706 and 730-731.
694 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
695 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
696
697 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
698 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
699 CHECK_OR_FALSE(ifm_circle->rank() == 3);
// NOTE(review): only dim(1) is checked for known(); dim(2) is read at 725 without a check
700 CHECK_OR_FALSE((ifm_circle->dim(1).known()));
701
702 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
703 CHECK_OR_FALSE(add_as_variance);
704
// add_as_variance must be mean_as_variance + const_as_epsilon, in either operand order
706 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
707
708 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
709 // TODO Support regarding broadcast
710 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
711
712 CHECK_OR_FALSE(is_instance_mean_v2(mean_as_variance));
713
714 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
715 CHECK_OR_FALSE(sqdiff);
716
717 loco::Node *ifm_should_be = nullptr;
718 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
719 CHECK_OR_FALSE(ifm == ifm_should_be);
721 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
722
723 // Accept const_as_beta only if it has shape '1 x channel x (1 or input last dimension)'
724 uint32_t input_channel = ifm_circle->dim(1).value();
725 uint32_t input_last_dim = ifm_circle->dim(2).value();
// NOTE(review): when beta's last dim equals input_last_dim (> 1), reshape_gamma_beta() later
// flattens it to channel * last_dim elements, not one value per channel. Confirm that
// CircleInstanceNorm accepts such a beta.
726 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
727 CHECK_OR_FALSE(const_as_beta);
728 CHECK_OR_FALSE(const_as_beta->rank() == 3);
730 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
731 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
732
// sub->y() must multiply the already matched rsqrt with the already matched mean
733 luci::CircleRsqrt *rsqrt_should_be = nullptr;
734 luci::CircleMean *mean_of_ifm_should_be = nullptr;
735
736 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
737 CHECK_OR_FALSE(mul_as_scaled_mean);
738 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
739 .with_commutative_args_of(mul_as_scaled_mean));
740 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
741 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
742
743 // mul_gamma is absent
744 // const_as_gamma is assumed to be 1.0
745 auto graph = add_as_terminal->graph();
746 const_as_gamma = make_const_one(graph, 1.0f);
747 const_as_gamma->name(add_as_terminal->name() + "/gamma");
748
749 _matched = true;
750 return true;
751}
752
753bool InstanceNormPattern::matched()
754{
755 if (_matched)
756 return true;
757
758 // Check order is DFS
759
760 switch (_pv)
761 {
762 case PatternVersion::Version_1:
763 return match<PatternVersion::Version_1>();
764 case PatternVersion::Version_2:
765 return match<PatternVersion::Version_2>();
766 case PatternVersion::Version_3:
767 return match<PatternVersion::Version_3>();
768 case PatternVersion::Version_4:
769 return match<PatternVersion::Version_4>();
770 case PatternVersion::Version_5:
771 return match<PatternVersion::Version_5>();
772 case PatternVersion::Version_6:
773 return match<PatternVersion::Version_6>();
774
775 default:
776 break;
777 }
778
779 throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
780}
781
782#undef CHECK_OR_FALSE
783
800class FuseInstanceNorm final
801{
802public:
803 FuseInstanceNorm(const InstanceNormPattern &p) : _p(p) {}
804
805public:
806 void apply(void);
807
808private:
809 template <InstanceNormPattern::PatternVersion> void apply(void);
810
811private:
812 void reshape_gamma_beta(void);
813 luci::CircleInstanceNorm *create_inst_norm(loco::Graph *graph);
814
815private:
816 const InstanceNormPattern &_p;
817};
818
819void FuseInstanceNorm::reshape_gamma_beta()
820{
821 // Version 1 and 3 need to reshape
822 {
823 _p.const_as_gamma->rank(1);
824 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
825 _p.const_as_beta->rank(1);
826 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
827
828 _p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
829 _p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
830 }
831}
832
833luci::CircleInstanceNorm *FuseInstanceNorm::create_inst_norm(loco::Graph *graph)
834{
835 // Make Instance Norm to replace
836 auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
837 instance_norm->input(_p.ifm);
838 instance_norm->gamma(_p.const_as_gamma);
839 instance_norm->beta(_p.const_as_beta);
840 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
841 instance_norm->epsilon(epsilon);
842 if (_p.add_as_terminal != nullptr)
843 {
844 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
845 // NOTE unique name should be assigned in export
846 instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
847 }
848 else
849 {
850 // VERSION_4
851 assert(_p.div != nullptr);
852 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
853 instance_norm->name("FusedInstanceNorm/" + _p.div->name());
854 }
855
856 return instance_norm;
857}
858
859template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
860{
861 auto graph = _p.add_as_terminal->graph();
862
863 reshape_gamma_beta();
864
865 auto instance_norm = create_inst_norm(graph);
866
867 // set origin
868 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
869 luci::get_origin(_p.mean_of_ifm),
870 luci::get_origin(_p.sqdiff),
871 luci::get_origin(_p.mean_as_variance),
872 luci::get_origin(_p.add_as_variance),
873 luci::get_origin(_p.rsqrt),
874 luci::get_origin(_p.mul_gamma),
875 luci::get_origin(_p.mul_as_scaled_ifm),
876 luci::get_origin(_p.mul_as_scaled_mean),
877 luci::get_origin(_p.sub),
878 luci::get_origin(_p.add_as_terminal)};
879
880 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
881
882 replace(_p.add_as_terminal).with(instance_norm);
883}
884
885template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
886{
887 auto graph = _p.add_as_terminal->graph();
888
889 auto instance_norm = create_inst_norm(graph);
890
891 // set origin
892 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
893 luci::get_origin(_p.mean_of_ifm),
894 luci::get_origin(_p.sqdiff),
895 luci::get_origin(_p.mean_as_variance),
896 luci::get_origin(_p.add_as_variance),
897 luci::get_origin(_p.pow),
898 luci::get_origin(_p.sub),
899 luci::get_origin(_p.div),
900 luci::get_origin(_p.mul_gamma),
901 luci::get_origin(_p.add_as_terminal)};
902
903 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
904
905 replace(_p.add_as_terminal).with(instance_norm);
906}
907
908template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
909{
910 auto graph = _p.add_as_terminal->graph();
911
912 reshape_gamma_beta();
913
914 auto instance_norm = create_inst_norm(graph);
915
916 // set origin
917 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
918 luci::get_origin(_p.mean_of_ifm),
919 luci::get_origin(_p.sub),
920 luci::get_origin(_p.mean_of_ifm_2),
921 luci::get_origin(_p.sub_2),
922 luci::get_origin(_p.square),
923 luci::get_origin(_p.mean_as_variance),
924 luci::get_origin(_p.sqrt),
925 luci::get_origin(_p.add_as_variance),
926 luci::get_origin(_p.div),
927 luci::get_origin(_p.mul_gamma),
928 luci::get_origin(_p.add_as_terminal)};
929
930 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
931
932 replace(_p.add_as_terminal).with(instance_norm);
933}
934
935template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
936{
937 auto graph = _p.div->graph();
938
939 auto instance_norm = create_inst_norm(graph);
940
941 // set origin
942 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
943 luci::get_origin(_p.mean_of_ifm),
944 luci::get_origin(_p.sub),
945 luci::get_origin(_p.mean_of_ifm_2),
946 luci::get_origin(_p.sub_2),
947 luci::get_origin(_p.square),
948 luci::get_origin(_p.mean_as_variance),
949 luci::get_origin(_p.sqrt),
950 luci::get_origin(_p.add_as_variance),
951 luci::get_origin(_p.div)};
952
953 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
954
955 replace(_p.div).with(instance_norm);
956}
957
958template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
959{
960 auto graph = _p.add_as_terminal->graph();
961
962 reshape_gamma_beta();
963
964 auto instance_norm = create_inst_norm(graph);
965
966 // set origin
967 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
968 luci::get_origin(_p.mean_of_ifm),
969 luci::get_origin(_p.sqdiff),
970 luci::get_origin(_p.mean_as_variance),
971 luci::get_origin(_p.add_as_variance),
972 luci::get_origin(_p.rsqrt),
973 luci::get_origin(_p.mul_as_scaled_ifm),
974 luci::get_origin(_p.mul_as_scaled_mean),
975 luci::get_origin(_p.sub),
976 luci::get_origin(_p.add_as_terminal)};
977
978 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
979
980 replace(_p.add_as_terminal).with(instance_norm);
981}
982
983template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
984{
985 auto graph = _p.add_as_terminal->graph();
986
987 reshape_gamma_beta();
988
989 auto instance_norm = create_inst_norm(graph);
990
991 // set origin
992 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
993 luci::get_origin(_p.mean_of_ifm),
994 luci::get_origin(_p.sqdiff),
995 luci::get_origin(_p.mean_as_variance),
996 luci::get_origin(_p.add_as_variance),
997 luci::get_origin(_p.rsqrt),
998 luci::get_origin(_p.mul_as_scaled_ifm),
999 luci::get_origin(_p.mul_as_scaled_mean),
1000 luci::get_origin(_p.sub),
1001 luci::get_origin(_p.add_as_terminal)};
1002
1003 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1004
1005 replace(_p.add_as_terminal).with(instance_norm);
1006}
1007
1008void FuseInstanceNorm::apply()
1009{
1010 assert(_p.matched());
1011
1012 switch (_p.version())
1013 {
1014 case InstanceNormPattern::PatternVersion::Version_1:
1015 apply<InstanceNormPattern::PatternVersion::Version_1>();
1016 break;
1017 case InstanceNormPattern::PatternVersion::Version_2:
1018 apply<InstanceNormPattern::PatternVersion::Version_2>();
1019 break;
1020 case InstanceNormPattern::PatternVersion::Version_3:
1021 apply<InstanceNormPattern::PatternVersion::Version_3>();
1022 break;
1023 case InstanceNormPattern::PatternVersion::Version_4:
1024 apply<InstanceNormPattern::PatternVersion::Version_4>();
1025 break;
1026 case InstanceNormPattern::PatternVersion::Version_5:
1027 apply<InstanceNormPattern::PatternVersion::Version_5>();
1028 break;
1029 case InstanceNormPattern::PatternVersion::Version_6:
1030 apply<InstanceNormPattern::PatternVersion::Version_6>();
1031 break;
1032
1033 default:
1034 break;
1035 }
1036}
1037
1038} // namespace
1039
1040namespace
1041{
1042
1043class PostFusion final
1044{
1045public:
1046 PostFusion(luci::CircleInstanceNorm *inst_norm) : _inst_norm(inst_norm) {}
1047
1048private:
1049 uint32_t input_channel(void);
1050
1051 luci::CircleConst *match_const_channel(luci::CircleConst *, uint32_t);
1052 bool match_const_gamma_channel(void);
1053 bool match_const_beta_channel(void);
1054
1055public:
1056 bool process(void);
1057
1058private:
1059 luci::CircleInstanceNorm *_inst_norm = nullptr;
1060};
1061
1065uint32_t PostFusion::input_channel(void)
1066{
1067 auto input = dynamic_cast<luci::CircleNode *>(_inst_norm->input());
1068 if (input == nullptr)
1069 return 0;
1070 if (input->shape_status() != luci::ShapeStatus::VALID)
1071 return 0;
1072
1073 auto input_rank = input->rank();
1074 if (input_rank < 1)
1075 return 0;
1076
1077 if (input_rank == 3)
1078 {
1079 // use dim 1
1080 return input->dim(1).value();
1081 }
1082 // assume channel-last
1083 return input->dim(input_rank - 1).value();
1084}
1085
1089luci::CircleConst *PostFusion::match_const_channel(luci::CircleConst *input_const, uint32_t C)
1090{
1091 luci::CircleConst *new_input_const = nullptr;
1092
1093 auto input_chn = input_const->dim(0).value();
1094 if (input_chn == 1 && input_chn != C)
1095 {
1096 float value = input_const->at<loco::DataType::FLOAT32>(0);
1097 auto clone = luci::clone_node(input_const, input_const->graph());
1098
1099 new_input_const = loco::must_cast<luci::CircleConst *>(clone);
1100 new_input_const->rank(1);
1101 new_input_const->dim(0).set(C);
1102 new_input_const->size<loco::DataType::FLOAT32>(C);
1103 for (uint32_t c = 0; c < C; ++c)
1104 new_input_const->at<loco::DataType::FLOAT32>(c) = value;
1105 }
1106
1107 return new_input_const;
1108}
1109
1113bool PostFusion::match_const_gamma_channel(void)
1114{
1115 auto const_as_gamma = dynamic_cast<luci::CircleConst *>(_inst_norm->gamma());
1116 if (const_as_gamma == nullptr)
1117 return false;
1118
1119 auto C = input_channel();
1120 if (C == 0)
1121 return false;
1122
1123 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1124 if (new_const_as_gamma == nullptr)
1125 return false;
1126
1127 _inst_norm->gamma(new_const_as_gamma);
1128
1129 return true;
1130}
1131
1135bool PostFusion::match_const_beta_channel(void)
1136{
1137 auto const_as_beta = dynamic_cast<luci::CircleConst *>(_inst_norm->beta());
1138 if (const_as_beta == nullptr)
1139 return false;
1140
1141 auto C = input_channel();
1142 if (C == 0)
1143 return false;
1144
1145 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1146 if (new_const_as_beta == nullptr)
1147 return false;
1148
1149 _inst_norm->beta(new_const_as_beta);
1150
1151 return true;
1152}
1153
1154bool PostFusion::process(void)
1155{
1156 bool changed = false;
1157
1158 if (match_const_gamma_channel())
1159 changed = true;
1160 if (match_const_beta_channel())
1161 changed = true;
1162
1163 return changed;
1164}
1165
1166} // namespace
1167
1168namespace
1169{
1170
1171bool is_add_input_mul_const(luci::CircleAdd *add)
1172{
1173 luci::CircleMul *p_mul = nullptr;
1174 luci::CircleConst *p_const = nullptr;
1175
1176 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1177}
1178
1179bool is_add_input_mul_sub3d(luci::CircleAdd *add)
1180{
1181 luci::CircleMul *p_mul = nullptr;
1182 luci::CircleSub *p_sub = nullptr;
1183
1184 if (!luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1185 return false;
1186
1187 auto sub = dynamic_cast<luci::CircleSub *>(add->y());
1188 if (sub == nullptr)
1189 return false;
1190
1191 auto const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
1192 if (const_as_beta == nullptr || const_as_beta->rank() != 3)
1193 return false;
1194
1195 return true;
1196}
1197
1198bool fuse_instance_norm(luci::CircleAdd *add)
1199{
1200 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1201
1202 if (is_add_input_mul_const(add))
1203 pv = InstanceNormPattern::PatternVersion::Version_2;
1204 else if (is_add_input_mul_sub3d(add))
1205 pv = InstanceNormPattern::PatternVersion::Version_6;
1206
1207 InstanceNormPattern pattern(add, pv);
1208 if (pattern.matched())
1209 {
1210 FuseInstanceNorm fuse(pattern);
1211 fuse.apply();
1212 return true;
1213 }
1214
1215 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1216 {
1217 // if Version_1 failed, try with Version_5
1218 pv = InstanceNormPattern::PatternVersion::Version_5;
1219 InstanceNormPattern pattern(add, pv);
1220 if (pattern.matched())
1221 {
1222 FuseInstanceNorm fuse(pattern);
1223 fuse.apply();
1224 return true;
1225 }
1226 }
1227 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1228 {
1229 // if Version_2 failed, try with Version_3
1230 pv = InstanceNormPattern::PatternVersion::Version_3;
1231 InstanceNormPattern pattern(add, pv);
1232 if (pattern.matched())
1233 {
1234 FuseInstanceNorm fuse(pattern);
1235 fuse.apply();
1236 return true;
1237 }
1238 }
1239
1240 return false;
1241}
1242
1243bool fuse_instance_norm(luci::CircleDiv *div)
1244{
1245 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1246
1247 InstanceNormPattern pattern(div, pv);
1248 if (pattern.matched())
1249 {
1250 FuseInstanceNorm fuse(pattern);
1251 fuse.apply();
1252 return true;
1253 }
1254
1255 return false;
1256}
1257
1258bool post_fusion(luci::CircleInstanceNorm *inst_norm)
1259{
1260 PostFusion postfusion(inst_norm);
1261
1262 return postfusion.process();
1263}
1264
1265} // namespace
1266
1267namespace luci
1268{
1269
1271{
// NOTE(review): body of the pass entry point. The member summary lists it as
// `bool run(loco::Graph *g) final` ("Run the pass"), but its signature line (1270) is missing
// from this rendering.
// Returns true if any subgraph was fused or any InstanceNorm's gamma/beta was rewritten.
1272 bool changed = false;
1273
1274 // Check Version_1, Version_2, Version_3, Version_5, Version_6
// Each loop iterates over active_nodes(), a std::set of the nodes needed to compute the graph
// outputs, built once before that loop starts.
1275 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1276 {
1277 auto add = dynamic_cast<luci::CircleAdd *>(node);
1278 if (not add)
1279 continue;
1280
1281 if (fuse_instance_norm(add))
1282 changed = true;
1283 }
1284
1285 // Check Version_4(from DIV) if MUL-ADD pattern is not found
// The node set is recomputed here. A DIV absorbed by the first loop is presumably no longer
// reachable from the outputs after replace(...).with(...), so it is not matched again.
1286 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1287 {
1288 auto div = dynamic_cast<luci::CircleDiv *>(node);
1289 if (not div)
1290 continue;
1291
1292 if (fuse_instance_norm(div))
1293 changed = true;
1294 }
1295
1296 // Post processing of FuseInstanceNorm
// Applies to every reachable InstanceNorm, whether fused above or already in the model
1297 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1298 {
1299 auto inst_norm = dynamic_cast<luci::CircleInstanceNorm *>(node);
1300 if (not inst_norm)
1301 continue;
1302
1303 if (post_fusion(inst_norm))
1304 changed = true;
1305 }
1306
1307 return changed;
1308}
1309
1310} // namespace luci
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
Graph * graph(void)
Definition Node.h:70
void with(Node *into) const
Definition Node.cpp:66
ADD in Circle.
Definition CircleAdd.h:34
Class to build tensor data.
Definition CircleConst.h:35
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
DIV in Circle.
Definition CircleDiv.h:37
INSTANCE_NORM in Circle.
loco::Node * input(void) const
MEAN in Circle.
Definition CircleMean.h:32
MUL in Circle.
Definition CircleMul.h:34
POW in Circle.
Definition CirclePow.h:32
RSQRT in Circle.
Definition CircleRsqrt.h:32
SQUARE in Circle.
SQUARED_DIFFERENCE in Circle.
SUB in Circle.
Definition CircleSub.h:34
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
C
Definition infer.py:52
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
T square(T value)
Definition Loss.h:37
Configuration p
bool run(loco::Graph *g) final
Run the pass.

Function Documentation

◆ is_1D_float32_const()

bool is_1D_float32_const ( const luci::CircleConst *node,
uint32_t  channel_size 
)
Returns
true when node is a rank-1 FLOAT32 constant whose single dimension and element count both equal channel_size

Definition at line 129 of file FuseInstanceNormPass.cpp.

130{
131 if (node->rank() != 1)
132 return false;
133
134 if (node->dim(0).value() != channel_size)
135 return false;
136
137 if (node->dtype() != loco::DataType::FLOAT32)
138 return false;
139
140 if (node->size<loco::DataType::FLOAT32>() != channel_size)
141 return false;
142
143 return true;
144}

References luci::CircleConst::size().

◆ is_1D_with_dummy_dim()

bool is_1D_with_dummy_dim ( luci::CircleConst *node,
uint32_t  depth 
)
Returns
true when node has shape '1 x .. x 1 x depth', i.e. every dimension except the last is 1 and the last equals depth

Definition at line 32 of file FuseInstanceNormPass.cpp.

33{
34 auto rank = node->rank();
35 uint32_t axis;
36 for (axis = 0; axis < rank - 1; ++axis)
37 {
38 if (node->dim(axis).value() != 1)
39 return false;
40 }
41 return node->dim(axis).value() == depth;
42}

◆ is_instance_mean_v1()

bool is_instance_mean_v1 ( luci::CircleMean *mean)

Definition at line 44 of file FuseInstanceNormPass.cpp.

45{
46 //
47 // CHECK 1) input is rank 4
48 //
49 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
50 if (input->shape_status() != luci::ShapeStatus::VALID)
51 return false;
52 if (input->rank() != 4)
53 return false;
54
55 //
56 // CHECK 2) 'reduction indices' is CircleConst of value [1,2], that is HW of NHWC
57 //
58 // TODO Support equivalent case, like [-3,-2]
59 // TODO Support non-Const case?
60 // TODO What if input is NCHW format in Circle?
61 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
62 if (not red_indices)
63 return false;
64 if (red_indices->rank() != 1)
65 return false;
66 std::set<int32_t> red_indices_set;
67 {
68 // TODO Currently only support S32, support other types
69 assert(red_indices->dtype() == loco::DataType::S32);
70 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
71 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
72 }
73 if (red_indices_set.size() != 2)
74 return false;
75 if (red_indices_set.find(1) == red_indices_set.end())
76 return false;
77 if (red_indices_set.find(2) == red_indices_set.end())
78 return false;
79
80 //
81 // CHECK 3) keep_dims == true (?)
82 //
83 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
84 // TODO Check this fact, and if true, return true regardless of keep_dims
85 return mean->keep_dims();
86}
bool keep_dims(void) const
Definition CircleMean.h:41
loco::Node * input(void) const
Definition CircleMean.h:34
loco::Node * reduction_indices(void) const
Definition CircleMean.h:37

References luci::CircleMean::input(), luci::CircleMean::keep_dims(), luci::CircleMean::reduction_indices(), and luci::VALID.

◆ is_instance_mean_v2()

bool is_instance_mean_v2 ( luci::CircleMean *mean)

Definition at line 88 of file FuseInstanceNormPass.cpp.

89{
90 //
91 // CHECK 1) input is rank 3
92 //
93 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
94 if (input->shape_status() != luci::ShapeStatus::VALID)
95 return false;
96 if (input->rank() != 3)
97 return false;
98
99 //
100 // CHECK 2) 'reduction indices' is CircleConst of value [2], that is last dim of rank 3
101 //
102 // TODO Support non-Const case?
103 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
104 if (not red_indices)
105 return false;
106 if (red_indices->rank() != 1)
107 return false;
108 std::set<int32_t> red_indices_set;
109 {
110 // TODO Currently only support S32, support other types
111 assert(red_indices->dtype() == loco::DataType::S32);
112 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
113 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
114 }
115 if (red_indices_set.size() != 1)
116 return false;
117 if (red_indices_set.find(2) == red_indices_set.end())
118 return false;
119
120 //
121 // CHECK 3) keep_dims == true (?)
122 //
123 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
124 // TODO Check this fact, and if true, return true regardless of keep_dims
125 return mean->keep_dims();
126}

References luci::CircleMean::input(), luci::CircleMean::keep_dims(), luci::CircleMean::reduction_indices(), and luci::VALID.