ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseInstanceNormPass.cpp File Reference

Go to the source code of this file.

Namespaces

namespace  luci
 

Macros

#define CHECK_OR_FALSE(condition)
 

Functions

bool is_1D_with_dummy_dim (luci::CircleConst *node, uint32_t depth)
 
bool is_instance_mean_v1 (luci::CircleMean *mean)
 
bool is_instance_mean_v2 (luci::CircleMean *mean)
 
bool is_1D_float32_const (const luci::CircleConst *node, uint32_t channel_size)
 

Macro Definition Documentation

◆ CHECK_OR_FALSE

#define CHECK_OR_FALSE (   condition)
Value:
if (not(condition)) \
return false;

Definition at line 446 of file FuseInstanceNormPass.cpp.

450{
451 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
452 CHECK_OR_FALSE(add_as_variance);
453
455 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
456
457 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
458 // TODO Support regarding broadcast
459 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
460
461 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
462
463 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
464 CHECK_OR_FALSE(sqdiff);
465
466 loco::Node *ifm_should_be = nullptr;
467 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
468 CHECK_OR_FALSE(ifm == ifm_should_be);
470 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
471
472 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
473 CHECK_OR_FALSE(const_as_beta);
474 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
475
476 return true;
477}
478
479bool InstanceNormPattern::condition_common_3_4()
480{
481 // check left sub
482 ifm = sub->x();
483 CHECK_OR_FALSE(ifm);
484
485 luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
486 CHECK_OR_FALSE(ifm_node->rank() == 4);
487 CHECK_OR_FALSE(ifm_node->dim(3).known());
488
489 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
490 CHECK_OR_FALSE(mean_of_ifm);
491 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
492
493 // continue search from add_as_variance
494 CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
495 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
496 // TODO Support regarding broadcast
497 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
498
499 mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
500 CHECK_OR_FALSE(mean_as_variance);
501
502 square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
503 CHECK_OR_FALSE(square);
504
505 sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
506 CHECK_OR_FALSE(sub_2);
507 CHECK_OR_FALSE(ifm == sub_2->x());
508
509 mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
510 CHECK_OR_FALSE(mean_of_ifm_2);
511 CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
512
513 loco::Node *ifm_should_be = nullptr;
514 luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
516 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
517 CHECK_OR_FALSE(ifm == ifm_should_be);
518 CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
519
520 return true;
521}
522
523template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
524{
525 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
526 CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
527
528 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
529 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
530 CHECK_OR_FALSE(ifm_circle->rank() == 4);
531 CHECK_OR_FALSE(ifm_circle->dim(3).known());
532 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
533
534 CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
535
536 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
537
538 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
539
540 luci::CircleMul *mul_gamma_should_be = nullptr;
541 luci::CircleMean *mean_of_ifm_should_be = nullptr;
542
543 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
544 CHECK_OR_FALSE(mul_as_scaled_mean);
545 CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
546 .with_commutative_args_of(mul_as_scaled_mean));
547 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
548 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
549
550 _matched = true;
551 return true;
552}
553
554template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
555{
556 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
557 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
558
559 sub = dynamic_cast<luci::CircleSub *>(div->x());
560 CHECK_OR_FALSE(sub);
561
562 ifm = sub->x();
563 CHECK_OR_FALSE(ifm);
564
565 luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
566 CHECK_OR_FALSE(ifm_node->rank() == 4);
567 CHECK_OR_FALSE(ifm_node->dim(3).known());
568 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
569
570 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
571 CHECK_OR_FALSE(mean_of_ifm);
572
573 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
574
575 pow = dynamic_cast<luci::CirclePow *>(div->y());
576 CHECK_OR_FALSE(pow);
577
578 add_as_variance = dynamic_cast<luci::CircleAdd *>(pow->x());
579 CHECK_OR_FALSE(add_as_variance);
580
581 luci::CircleConst *zero_point_five = dynamic_cast<luci::CircleConst *>(pow->y());
582 CHECK_OR_FALSE(zero_point_five);
583 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
584 // TODO Support regarding broadcast
585 CHECK_OR_FALSE(zero_point_five->size<loco::DataType::FLOAT32>() == 1);
586 CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
587
589 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
590 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
591 // TODO Support regarding broadcast
592 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
593
594 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
595
596 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
597 CHECK_OR_FALSE(sqdiff);
598
599 loco::Node *ifm_should_be = nullptr;
600 luci::CircleMean *mean_of_ifm_should_be = nullptr;
602 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
603 CHECK_OR_FALSE(ifm == ifm_should_be);
604 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
605
606 // Check for channel size
607 CHECK_OR_FALSE(is_1D_float32_const(const_as_gamma, ifm_channel_depth));
608 CHECK_OR_FALSE(is_1D_float32_const(const_as_beta, ifm_channel_depth));
609
610 _matched = true;
611 return true;
612}
613
614template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
615{
616 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
617 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
618 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
619
620 CHECK_OR_FALSE(condition_common_3_4());
621
622 _matched = true;
623 return true;
624}
625
626luci::CircleConst *make_const_one(loco::Graph *graph, float value)
627{
628 auto const_one = graph->nodes()->create<luci::CircleConst>();
629 const_one->dtype(loco::DataType::FLOAT32);
630 const_one->rank(1);
631 const_one->size<loco::DataType::FLOAT32>(1);
632 const_one->at<loco::DataType::FLOAT32>(0) = value;
633 return const_one;
634}
635
636template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
637{
638 CHECK_OR_FALSE(div);
639 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
640
641 CHECK_OR_FALSE(condition_common_3_4());
642
643 assert(const_as_gamma == nullptr);
644 assert(const_as_beta == nullptr);
645 assert(mul_gamma == nullptr);
646 assert(add_as_terminal == nullptr);
647
648 // create 1.0 gamma and 0.0 beta
649 auto graph = div->graph();
650 const_as_gamma = make_const_one(graph, 1.0f);
651 const_as_beta = make_const_one(graph, 0.0f);
652 const_as_gamma->name(div->name() + "/gamma");
653 const_as_beta->name(div->name() + "/beta");
654
655 _matched = true;
656 return true;
657}
658
659template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
660{
661 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
662 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
663
664 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
665 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
666 CHECK_OR_FALSE(ifm_circle->rank() == 4);
667 CHECK_OR_FALSE(ifm_circle->dim(3).known());
668 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
669
670 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
671
672 luci::CircleRsqrt *rsqrt_should_be = nullptr;
673 luci::CircleMean *mean_of_ifm_should_be = nullptr;
674
675 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
676 CHECK_OR_FALSE(mul_as_scaled_mean);
677 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
678 .with_commutative_args_of(mul_as_scaled_mean));
679 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
680 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
681
682 // mul_gamma is absent
683 // const_as_gamma assume to be 1.0
684 auto graph = add_as_terminal->graph();
685 const_as_gamma = make_const_one(graph, 1.0f);
686 const_as_gamma->name(add_as_terminal->name() + "/gamma");
687
688 _matched = true;
689 return true;
690}
691
692template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
693{
694 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
695 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
696
697 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
698 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
699 CHECK_OR_FALSE(ifm_circle->rank() == 3);
700 CHECK_OR_FALSE((ifm_circle->dim(1).known()));
701
702 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
703 CHECK_OR_FALSE(add_as_variance);
704
706 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
707
708 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
709 // TODO Support regarding broadcast
710 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
711
712 CHECK_OR_FALSE(is_instance_mean_v2(mean_as_variance));
713
714 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
715 CHECK_OR_FALSE(sqdiff);
716
717 loco::Node *ifm_should_be = nullptr;
718 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
719 CHECK_OR_FALSE(ifm == ifm_should_be);
721 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
722
723 // If const_as_beta has shape of '1 x chennel x (1 or input last dimension)'
724 uint32_t input_channel = ifm_circle->dim(1).value();
725 uint32_t input_last_dim = ifm_circle->dim(2).value();
726 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
727 CHECK_OR_FALSE(const_as_beta);
728 CHECK_OR_FALSE(const_as_beta->rank() == 3);
730 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
731 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
732
733 luci::CircleRsqrt *rsqrt_should_be = nullptr;
734 luci::CircleMean *mean_of_ifm_should_be = nullptr;
735
736 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
737 CHECK_OR_FALSE(mul_as_scaled_mean);
738 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
739 .with_commutative_args_of(mul_as_scaled_mean));
740 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
741 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
742
743 // mul_gamma is absent
744 // const_as_gamma assume to be 1.0
745 auto graph = add_as_terminal->graph();
746 const_as_gamma = make_const_one(graph, 1.0f);
747 const_as_gamma->name(add_as_terminal->name() + "/gamma");
748
749 _matched = true;
750 return true;
751}
752
753bool InstanceNormPattern::matched()
754{
755 if (_matched)
756 return true;
757
758 // Check order is DFS
759
760 switch (_pv)
761 {
762 case PatternVersion::Version_1:
763 return match<PatternVersion::Version_1>();
764 case PatternVersion::Version_2:
765 return match<PatternVersion::Version_2>();
766 case PatternVersion::Version_3:
767 return match<PatternVersion::Version_3>();
768 case PatternVersion::Version_4:
769 return match<PatternVersion::Version_4>();
770 case PatternVersion::Version_5:
771 return match<PatternVersion::Version_5>();
772 case PatternVersion::Version_6:
773 return match<PatternVersion::Version_6>();
774
775 default:
776 break;
777 }
778
779 throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
780}
781
782#undef CHECK_OR_FALSE
783
800class FuseInstanceNorm final
801{
802public:
803 FuseInstanceNorm(const InstanceNormPattern &p) : _p(p) {}
804
805public:
806 void apply(void);
807
808private:
809 template <InstanceNormPattern::PatternVersion> void apply(void);
810
811private:
812 void reshape_gamma_beta(void);
813 luci::CircleInstanceNorm *create_inst_norm(loco::Graph *graph);
814
815private:
816 const InstanceNormPattern &_p;
817};
818
819void FuseInstanceNorm::reshape_gamma_beta()
820{
821 // Version 1 and 3 need to reshape
822 {
823 _p.const_as_gamma->rank(1);
824 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
825 _p.const_as_beta->rank(1);
826 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
827
828 _p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
829 _p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
830 }
831}
832
833luci::CircleInstanceNorm *FuseInstanceNorm::create_inst_norm(loco::Graph *graph)
834{
835 // Make Instance Norm to replace
836 auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
837 instance_norm->input(_p.ifm);
838 instance_norm->gamma(_p.const_as_gamma);
839 instance_norm->beta(_p.const_as_beta);
840 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
841 instance_norm->epsilon(epsilon);
842 if (_p.add_as_terminal != nullptr)
843 {
844 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
845 // NOTE unique name should be assigned in export
846 instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
847 }
848 else
849 {
850 // VERSION_4
851 assert(_p.div != nullptr);
852 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
853 instance_norm->name("FusedInstanceNorm/" + _p.div->name());
854 }
855
856 return instance_norm;
857}
858
859template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
860{
861 auto graph = _p.add_as_terminal->graph();
862
863 reshape_gamma_beta();
864
865 auto instance_norm = create_inst_norm(graph);
866
867 // set origin
868 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
869 luci::get_origin(_p.mean_of_ifm),
870 luci::get_origin(_p.sqdiff),
871 luci::get_origin(_p.mean_as_variance),
872 luci::get_origin(_p.add_as_variance),
873 luci::get_origin(_p.rsqrt),
874 luci::get_origin(_p.mul_gamma),
875 luci::get_origin(_p.mul_as_scaled_ifm),
876 luci::get_origin(_p.mul_as_scaled_mean),
877 luci::get_origin(_p.sub),
878 luci::get_origin(_p.add_as_terminal)};
879
880 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
881
882 replace(_p.add_as_terminal).with(instance_norm);
883}
884
885template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
886{
887 auto graph = _p.add_as_terminal->graph();
888
889 auto instance_norm = create_inst_norm(graph);
890
891 // set origin
892 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
893 luci::get_origin(_p.mean_of_ifm),
894 luci::get_origin(_p.sqdiff),
895 luci::get_origin(_p.mean_as_variance),
896 luci::get_origin(_p.add_as_variance),
897 luci::get_origin(_p.pow),
898 luci::get_origin(_p.sub),
899 luci::get_origin(_p.div),
900 luci::get_origin(_p.mul_gamma),
901 luci::get_origin(_p.add_as_terminal)};
902
903 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
904
905 replace(_p.add_as_terminal).with(instance_norm);
906}
907
908template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
909{
910 auto graph = _p.add_as_terminal->graph();
911
912 reshape_gamma_beta();
913
914 auto instance_norm = create_inst_norm(graph);
915
916 // set origin
917 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
918 luci::get_origin(_p.mean_of_ifm),
919 luci::get_origin(_p.sub),
920 luci::get_origin(_p.mean_of_ifm_2),
921 luci::get_origin(_p.sub_2),
922 luci::get_origin(_p.square),
923 luci::get_origin(_p.mean_as_variance),
924 luci::get_origin(_p.sqrt),
925 luci::get_origin(_p.add_as_variance),
926 luci::get_origin(_p.div),
927 luci::get_origin(_p.mul_gamma),
928 luci::get_origin(_p.add_as_terminal)};
929
930 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
931
932 replace(_p.add_as_terminal).with(instance_norm);
933}
934
935template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
936{
937 auto graph = _p.div->graph();
938
939 auto instance_norm = create_inst_norm(graph);
940
941 // set origin
942 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
943 luci::get_origin(_p.mean_of_ifm),
944 luci::get_origin(_p.sub),
945 luci::get_origin(_p.mean_of_ifm_2),
946 luci::get_origin(_p.sub_2),
947 luci::get_origin(_p.square),
948 luci::get_origin(_p.mean_as_variance),
949 luci::get_origin(_p.sqrt),
950 luci::get_origin(_p.add_as_variance),
951 luci::get_origin(_p.div)};
952
953 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
954
955 replace(_p.div).with(instance_norm);
956}
957
958template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
959{
960 auto graph = _p.add_as_terminal->graph();
961
962 reshape_gamma_beta();
963
964 auto instance_norm = create_inst_norm(graph);
965
966 // set origin
967 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
968 luci::get_origin(_p.mean_of_ifm),
969 luci::get_origin(_p.sqdiff),
970 luci::get_origin(_p.mean_as_variance),
971 luci::get_origin(_p.add_as_variance),
972 luci::get_origin(_p.rsqrt),
973 luci::get_origin(_p.mul_as_scaled_ifm),
974 luci::get_origin(_p.mul_as_scaled_mean),
975 luci::get_origin(_p.sub),
976 luci::get_origin(_p.add_as_terminal)};
977
978 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
979
980 replace(_p.add_as_terminal).with(instance_norm);
981}
982
983template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
984{
985 auto graph = _p.add_as_terminal->graph();
986
987 reshape_gamma_beta();
988
989 auto instance_norm = create_inst_norm(graph);
990
991 // set origin
992 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
993 luci::get_origin(_p.mean_of_ifm),
994 luci::get_origin(_p.sqdiff),
995 luci::get_origin(_p.mean_as_variance),
996 luci::get_origin(_p.add_as_variance),
997 luci::get_origin(_p.rsqrt),
998 luci::get_origin(_p.mul_as_scaled_ifm),
999 luci::get_origin(_p.mul_as_scaled_mean),
1000 luci::get_origin(_p.sub),
1001 luci::get_origin(_p.add_as_terminal)};
1002
1003 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1004
1005 replace(_p.add_as_terminal).with(instance_norm);
1006}
1007
1008void FuseInstanceNorm::apply()
1009{
1010 assert(_p.matched());
1011
1012 switch (_p.version())
1013 {
1014 case InstanceNormPattern::PatternVersion::Version_1:
1015 apply<InstanceNormPattern::PatternVersion::Version_1>();
1016 break;
1017 case InstanceNormPattern::PatternVersion::Version_2:
1018 apply<InstanceNormPattern::PatternVersion::Version_2>();
1019 break;
1020 case InstanceNormPattern::PatternVersion::Version_3:
1021 apply<InstanceNormPattern::PatternVersion::Version_3>();
1022 break;
1023 case InstanceNormPattern::PatternVersion::Version_4:
1024 apply<InstanceNormPattern::PatternVersion::Version_4>();
1025 break;
1026 case InstanceNormPattern::PatternVersion::Version_5:
1027 apply<InstanceNormPattern::PatternVersion::Version_5>();
1028 break;
1029 case InstanceNormPattern::PatternVersion::Version_6:
1030 apply<InstanceNormPattern::PatternVersion::Version_6>();
1031 break;
1032
1033 default:
1034 break;
1035 }
1036}
1037
1038} // namespace
1039
1040namespace
1041{
1042
1043class PostFusion final
1044{
1045public:
1046 PostFusion(luci::CircleInstanceNorm *inst_norm) : _inst_norm(inst_norm) {}
1047
1048private:
1049 uint32_t input_channel(void);
1050
1051 luci::CircleConst *match_const_channel(luci::CircleConst *, uint32_t);
1052 bool match_const_gamma_channel(void);
1053 bool match_const_beta_channel(void);
1054
1055public:
1056 bool process(void);
1057
1058private:
1059 luci::CircleInstanceNorm *_inst_norm = nullptr;
1060};
1061
1065uint32_t PostFusion::input_channel(void)
1066{
1067 auto input = dynamic_cast<luci::CircleNode *>(_inst_norm->input());
1068 if (input == nullptr)
1069 return 0;
1070 if (input->shape_status() != luci::ShapeStatus::VALID)
1071 return 0;
1072
1073 auto input_rank = input->rank();
1074 if (input_rank < 1)
1075 return 0;
1076
1077 if (input_rank == 3)
1078 {
1079 // use dim 1
1080 return input->dim(1).value();
1081 }
1082 // assume channel-last
1083 return input->dim(input_rank - 1).value();
1084}
1085
1089luci::CircleConst *PostFusion::match_const_channel(luci::CircleConst *input_const, uint32_t C)
1090{
1091 luci::CircleConst *new_input_const = nullptr;
1092
1093 auto input_chn = input_const->dim(0).value();
1094 if (input_chn == 1 && input_chn != C)
1095 {
1096 float value = input_const->at<loco::DataType::FLOAT32>(0);
1097 auto clone = luci::clone_node(input_const, input_const->graph());
1098
1099 new_input_const = loco::must_cast<luci::CircleConst *>(clone);
1100 new_input_const->rank(1);
1101 new_input_const->dim(0).set(C);
1102 new_input_const->size<loco::DataType::FLOAT32>(C);
1103 for (uint32_t c = 0; c < C; ++c)
1104 new_input_const->at<loco::DataType::FLOAT32>(c) = value;
1105 }
1106
1107 return new_input_const;
1108}
1109
1113bool PostFusion::match_const_gamma_channel(void)
1114{
1115 auto const_as_gamma = dynamic_cast<luci::CircleConst *>(_inst_norm->gamma());
1116 if (const_as_gamma == nullptr)
1117 return false;
1118
1119 auto C = input_channel();
1120 if (C == 0)
1121 return false;
1122
1123 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1124 if (new_const_as_gamma == nullptr)
1125 return false;
1126
1127 _inst_norm->gamma(new_const_as_gamma);
1128
1129 return true;
1130}
1131
1135bool PostFusion::match_const_beta_channel(void)
1136{
1137 auto const_as_beta = dynamic_cast<luci::CircleConst *>(_inst_norm->beta());
1138 if (const_as_beta == nullptr)
1139 return false;
1140
1141 auto C = input_channel();
1142 if (C == 0)
1143 return false;
1144
1145 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1146 if (new_const_as_beta == nullptr)
1147 return false;
1148
1149 _inst_norm->beta(new_const_as_beta);
1150
1151 return true;
1152}
1153
1154bool PostFusion::process(void)
1155{
1156 bool changed = false;
1157
1158 if (match_const_gamma_channel())
1159 changed = true;
1160 if (match_const_beta_channel())
1161 changed = true;
1162
1163 return changed;
1164}
1165
1166} // namespace
1167
1168namespace
1169{
1170
1171bool is_add_input_mul_const(luci::CircleAdd *add)
1172{
1173 luci::CircleMul *p_mul = nullptr;
1174 luci::CircleConst *p_const = nullptr;
1175
1176 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1177}
1178
1179bool is_add_input_mul_sub3d(luci::CircleAdd *add)
1180{
1181 luci::CircleMul *p_mul = nullptr;
1182 luci::CircleSub *p_sub = nullptr;
1183
1184 if (!luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1185 return false;
1186
1187 auto sub = dynamic_cast<luci::CircleSub *>(add->y());
1188 if (sub == nullptr)
1189 return false;
1190
1191 auto const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
1192 if (const_as_beta == nullptr || const_as_beta->rank() != 3)
1193 return false;
1194
1195 return true;
1196}
1197
1198bool fuse_instance_norm(luci::CircleAdd *add)
1199{
1200 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1201
1202 if (is_add_input_mul_const(add))
1203 pv = InstanceNormPattern::PatternVersion::Version_2;
1204 else if (is_add_input_mul_sub3d(add))
1205 pv = InstanceNormPattern::PatternVersion::Version_6;
1206
1207 InstanceNormPattern pattern(add, pv);
1208 if (pattern.matched())
1209 {
1210 FuseInstanceNorm fuse(pattern);
1211 fuse.apply();
1212 return true;
1213 }
1214
1215 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1216 {
1217 // if Version_1 failed, try with Version_5
1218 pv = InstanceNormPattern::PatternVersion::Version_5;
1219 InstanceNormPattern pattern(add, pv);
1220 if (pattern.matched())
1221 {
1222 FuseInstanceNorm fuse(pattern);
1223 fuse.apply();
1224 return true;
1225 }
1226 }
1227 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1228 {
1229 // if Version_2 failed, try with Version_3
1230 pv = InstanceNormPattern::PatternVersion::Version_3;
1231 InstanceNormPattern pattern(add, pv);
1232 if (pattern.matched())
1233 {
1234 FuseInstanceNorm fuse(pattern);
1235 fuse.apply();
1236 return true;
1237 }
1238 }
1239
1240 return false;
1241}
1242
1243bool fuse_instance_norm(luci::CircleDiv *div)
1244{
1245 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1246
1247 InstanceNormPattern pattern(div, pv);
1248 if (pattern.matched())
1249 {
1250 FuseInstanceNorm fuse(pattern);
1251 fuse.apply();
1252 return true;
1253 }
1254
1255 return false;
1256}
1257
1258bool post_fusion(luci::CircleInstanceNorm *inst_norm)
1259{
1260 PostFusion postfusion(inst_norm);
1261
1262 return postfusion.process();
1263}
1264
1265} // namespace
1266
1267namespace luci
1268{
1269
1271{
1272 bool changed = false;
1273
1274 // Check Version_1, Version_2, Version_3, Version_5, Version_6
1275 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1276 {
1277 auto add = dynamic_cast<luci::CircleAdd *>(node);
1278 if (not add)
1279 continue;
1280
1281 if (fuse_instance_norm(add))
1282 changed = true;
1283 }
1284
1285 // Check Version_4(from DIV) if MUL-ADD pattern is not found
1286 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1287 {
1288 auto div = dynamic_cast<luci::CircleDiv *>(node);
1289 if (not div)
1290 continue;
1291
1292 if (fuse_instance_norm(div))
1293 changed = true;
1294 }
1295
1296 // Post processing of FuseInstanceNorm
1297 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1298 {
1299 auto inst_norm = dynamic_cast<luci::CircleInstanceNorm *>(node);
1300 if (not inst_norm)
1301 continue;
1302
1303 if (post_fusion(inst_norm))
1304 changed = true;
1305 }
1306
1307 return changed;
1308}
1309
1310} // namespace luci
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
Graph * graph(void)
Definition Node.h:70
void with(Node *into) const
Definition Node.cpp:66
ADD in Circle.
Definition CircleAdd.h:34
Class to build tensor data.
Definition CircleConst.h:35
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
DIV in Circle.
Definition CircleDiv.h:37
INSTANCE_NORM in Circle.
loco::Node * input(void) const
MEAN in Circle.
Definition CircleMean.h:32
MUL in Circle.
Definition CircleMul.h:34
POW in Circle.
Definition CirclePow.h:32
RSQRT in Circle.
Definition CircleRsqrt.h:32
SQUARE in Circle.
SQUARED_DIFFERENCE in Circle.
SUB in Circle.
Definition CircleSub.h:34
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
C
Definition infer.py:52
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
T square(T value)
Definition Loss.h:37
bool run(loco::Graph *g) final
Run the pass.

Function Documentation

◆ is_1D_float32_const()

bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
Returns
true when node is a rank-1 FLOAT32 constant of shape [channel_size] holding exactly channel_size elements

Definition at line 129 of file FuseInstanceNormPass.cpp.

130{
131 if (node->rank() != 1)
132 return false;
133
134 if (node->dim(0).value() != channel_size)
135 return false;
136
137 if (node->dtype() != loco::DataType::FLOAT32)
138 return false;
139
140 if (node->size<loco::DataType::FLOAT32>() != channel_size)
141 return false;
142
143 return true;
144}

References luci::CircleConst::size().

◆ is_1D_with_dummy_dim()

bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
Returns
true when node has shape '1 x ... x 1 x depth', i.e. every dimension except the last is 1 and the last equals depth

Definition at line 32 of file FuseInstanceNormPass.cpp.

33{
34 auto rank = node->rank();
35 uint32_t axis;
36 for (axis = 0; axis < rank - 1; ++axis)
37 {
38 if (node->dim(axis).value() != 1)
39 return false;
40 }
41 return node->dim(axis).value() == depth;
42}

◆ is_instance_mean_v1()

bool is_instance_mean_v1(luci::CircleMean *mean)

Definition at line 44 of file FuseInstanceNormPass.cpp.

45{
46 //
47 // CHECK 1) input is rank 4
48 //
49 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
50 if (input->shape_status() != luci::ShapeStatus::VALID)
51 return false;
52 if (input->rank() != 4)
53 return false;
54
55 //
56 // CHECK 2) 'reduction indices' is CircleConst of value [1,2], that is HW of NHWC
57 //
58 // TODO Support equivalent case, like [-3,-2]
59 // TODO Support non-Const case?
60 // TODO What if input is NCHW format in Circle?
61 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
62 if (not red_indices)
63 return false;
64 if (red_indices->rank() != 1)
65 return false;
66 std::set<int32_t> red_indices_set;
67 {
68 // TODO Currently only support S32, support other types
69 assert(red_indices->dtype() == loco::DataType::S32);
70 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
71 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
72 }
73 if (red_indices_set.size() != 2)
74 return false;
75 if (red_indices_set.find(1) == red_indices_set.end())
76 return false;
77 if (red_indices_set.find(2) == red_indices_set.end())
78 return false;
79
80 //
81 // CHECK 3) keep_dims == true (?)
82 //
83 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
84 // TODO Check this fact, and if true, return true regardless of keep_dims
85 return mean->keep_dims();
86}
bool keep_dims(void) const
Definition CircleMean.h:41
loco::Node * input(void) const
Definition CircleMean.h:34
loco::Node * reduction_indices(void) const
Definition CircleMean.h:37

References luci::CircleMean::input(), luci::CircleMean::keep_dims(), luci::CircleMean::reduction_indices(), and luci::VALID.

◆ is_instance_mean_v2()

bool is_instance_mean_v2(luci::CircleMean *mean)

Definition at line 88 of file FuseInstanceNormPass.cpp.

89{
90 //
91 // CHECK 1) input is rank 3
92 //
93 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
94 if (input->shape_status() != luci::ShapeStatus::VALID)
95 return false;
96 if (input->rank() != 3)
97 return false;
98
99 //
100 // CHECK 2) 'reduction indices' is CircleConst of value [2], that is last dim of rank 3
101 //
102 // TODO Support non-Const case?
103 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
104 if (not red_indices)
105 return false;
106 if (red_indices->rank() != 1)
107 return false;
108 std::set<int32_t> red_indices_set;
109 {
110 // TODO Currently only support S32, support other types
111 assert(red_indices->dtype() == loco::DataType::S32);
112 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
113 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
114 }
115 if (red_indices_set.size() != 1)
116 return false;
117 if (red_indices_set.find(2) == red_indices_set.end())
118 return false;
119
120 //
121 // CHECK 3) keep_dims == true (?)
122 //
123 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
124 // TODO Check this fact, and if true, return true regardless of keep_dims
125 return mean->keep_dims();
126}

References luci::CircleMean::input(), luci::CircleMean::keep_dims(), luci::CircleMean::reduction_indices(), and luci::VALID.