ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FuseInstanceNormPass.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "helpers/NodeFiller.h"
20
21#include <luci/IR/CircleNodes.h>
22
25
26#include <cassert>
27#include <set>
28
29// Helper to check detail
30
32bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
33{
34 auto rank = node->rank();
35 uint32_t axis;
36 for (axis = 0; axis < rank - 1; ++axis)
37 {
38 if (node->dim(axis).value() != 1)
39 return false;
40 }
41 return node->dim(axis).value() == depth;
42}
43
45{
46 //
47 // CHECK 1) input is rank 4
48 //
49 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
50 if (input->shape_status() != luci::ShapeStatus::VALID)
51 return false;
52 if (input->rank() != 4)
53 return false;
54
55 //
56 // CHECK 2) 'reduction indices' is CircleConst of value [1,2], that is HW of NHWC
57 //
58 // TODO Support equivalent case, like [-3,-2]
59 // TODO Support non-Const case?
60 // TODO What if input is NCHW format in Circle?
61 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
62 if (not red_indices)
63 return false;
64 if (red_indices->rank() != 1)
65 return false;
66 std::set<int32_t> red_indices_set;
67 {
68 // TODO Currently only support S32, support other types
69 assert(red_indices->dtype() == loco::DataType::S32);
70 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
71 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
72 }
73 if (red_indices_set.size() != 2)
74 return false;
75 if (red_indices_set.find(1) == red_indices_set.end())
76 return false;
77 if (red_indices_set.find(2) == red_indices_set.end())
78 return false;
79
80 //
81 // CHECK 3) keep_dims == true (?)
82 //
83 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
84 // TODO Check this fact, and if true, return true regardless of keep_dims
85 return mean->keep_dims();
86}
87
89{
90 //
91 // CHECK 1) input is rank 3
92 //
93 auto input = loco::must_cast<luci::CircleNode *>(mean->input());
94 if (input->shape_status() != luci::ShapeStatus::VALID)
95 return false;
96 if (input->rank() != 3)
97 return false;
98
99 //
100 // CHECK 2) 'reduction indices' is CircleConst of value [2], that is last dim of rank 3
101 //
102 // TODO Support non-Const case?
103 auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
104 if (not red_indices)
105 return false;
106 if (red_indices->rank() != 1)
107 return false;
108 std::set<int32_t> red_indices_set;
109 {
110 // TODO Currently only support S32, support other types
111 assert(red_indices->dtype() == loco::DataType::S32);
112 for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
113 red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
114 }
115 if (red_indices_set.size() != 1)
116 return false;
117 if (red_indices_set.find(2) == red_indices_set.end())
118 return false;
119
120 //
121 // CHECK 3) keep_dims == true (?)
122 //
123 // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
124 // TODO Check this fact, and if true, return true regardless of keep_dims
125 return mean->keep_dims();
126}
127
129bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
130{
131 if (node->rank() != 1)
132 return false;
133
134 if (node->dim(0).value() != channel_size)
135 return false;
136
137 if (node->dtype() != loco::DataType::FLOAT32)
138 return false;
139
140 if (node->size<loco::DataType::FLOAT32>() != channel_size)
141 return false;
142
143 return true;
144}
145
146// Helper to fuse Instance Norm
147namespace
148{
149
374class InstanceNormPattern final
375{
376public:
377 enum PatternVersion
378 {
379 Version_Unknown,
380 Version_1,
381 Version_2,
382 Version_3,
383 Version_4,
384 Version_5,
385 Version_6, // For only 3D I/O
386 };
387
388 InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv)
389 {
390 assert(candidate);
391 add_as_terminal = candidate;
392 _pv = pv;
393 }
394
395 InstanceNormPattern(luci::CircleDiv *candidate, PatternVersion pv)
396 {
397 assert(candidate);
398 div = candidate;
399 _pv = pv;
400 }
401
402private:
403 bool condition_common_1_5(uint32_t ifm_channel_depth);
404 bool condition_common_3_4();
405
406private:
407 template <enum PatternVersion> bool match();
408
409public:
410 bool matched();
411 bool matched() const { return _matched; }
412
413 PatternVersion version() const { return _pv; }
414
415public:
416 // Context
417 loco::Node *ifm = nullptr;
418 luci::CircleReshape *reshape_of_ifm = nullptr;
419 luci::CircleMean *mean_of_ifm = nullptr;
420 luci::CircleMean *mean_of_ifm_2 = nullptr;
421 luci::CircleMean *mean_of_reshape = nullptr;
422 luci::CircleSquaredDifference *sqdiff = nullptr;
423 luci::CircleSquare *square = nullptr;
424 luci::CircleMean *mean_as_variance = nullptr;
425 luci::CircleConst *const_as_epsilon = nullptr;
426 luci::CircleAdd *add_as_variance = nullptr;
427 luci::CircleRsqrt *rsqrt = nullptr;
428 luci::CircleConst *const_as_gamma = nullptr;
429 luci::CircleMul *mul_gamma = nullptr;
430 luci::CircleMul *mul_as_scaled_ifm = nullptr;
431 luci::CircleMul *mul_as_scaled_mean = nullptr;
432 luci::CircleMul *mul_as_scaled_reshape = nullptr;
433 luci::CircleConst *const_as_beta = nullptr;
434 luci::CircleSub *sub = nullptr;
435 luci::CircleSub *sub_2 = nullptr;
436 luci::CircleAdd *add_as_terminal = nullptr;
437 luci::CirclePow *pow = nullptr;
438 luci::CircleSqrt *sqrt = nullptr;
439 luci::CircleDiv *div = nullptr;
440
441private:
442 bool _matched = false;
443 PatternVersion _pv;
444};
445
// Return false from the enclosing function unless 'condition' holds.
// The do/while(false) wrapper makes the expansion exactly one statement that needs its
// trailing ';', so the macro stays safe inside an unbraced if/else.
#define CHECK_OR_FALSE(condition) \
  do                              \
  {                               \
    if (not(condition))           \
      return false;               \
  } while (false)
449
// Checks shared by the Version_1 and Version_5 matchers, which both call this after they
// have filled 'sub', 'ifm' and 'rsqrt'. Starting from rsqrt it fills add_as_variance,
// mean_as_variance, const_as_epsilon, sqdiff, mean_of_ifm and const_as_beta, returning
// false at the first node that does not fit.
//
// ifm_channel_depth: extent that the last axis of const_as_beta must have.
//
// NOTE(review): the listing numbering jumps 454->456 and 469->471, so two source lines are
// missing here. The unbalanced ')' after 454 suggests the first one was a 'CHECK_OR_FALSE('
// opener; the content of the second cannot be told from this view - confirm against the
// real file before editing this function.
450bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
451{
 // rsqrt must be applied to an Add (variance + epsilon)
452 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
453 CHECK_OR_FALSE(add_as_variance);
454
 // The Add's operands, in either order, are a Mean and a Const
456 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
457
 // epsilon must be a single FLOAT32 value
458 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
459 // TODO Support regarding broadcast
460 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
461
462 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
463
 // The variance Mean must consume a SquaredDifference ...
464 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
465 CHECK_OR_FALSE(sqdiff);
466
 // ... whose operands are the already-known ifm and a Mean taken over that same ifm
467 loco::Node *ifm_should_be = nullptr;
468 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
469 CHECK_OR_FALSE(ifm == ifm_should_be);
471 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
472
 // beta is the left operand of 'sub' and must be shaped [1, ..., 1, ifm_channel_depth]
473 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
474 CHECK_OR_FALSE(const_as_beta);
475 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
476
477 return true;
478}
479
// Checks shared by the Version_3 and Version_4 matchers, which both call this after they
// have filled 'sub' and 'add_as_variance' from the Div. It fills ifm, mean_of_ifm, sqrt,
// const_as_epsilon, mean_as_variance, square, sub_2 and mean_of_ifm_2, returning false at
// the first node that does not fit.
//
// NOTE(review): the listing numbering jumps 515->517; the unbalanced ')' suggests the
// missing line was a 'CHECK_OR_FALSE(' opener - confirm against the real file.
480bool InstanceNormPattern::condition_common_3_4()
481{
482 // check left sub
483 ifm = sub->x();
484 CHECK_OR_FALSE(ifm);
485
 // The input must be rank 4 with a known extent on axis 3
 // NOTE(review): unlike the Version_1 matcher, shape_status() is not checked before rank()
486 luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
487 CHECK_OR_FALSE(ifm_node->rank() == 4);
488 CHECK_OR_FALSE(ifm_node->dim(3).known());
489
 // 'sub' is ifm minus a Mean taken over that same ifm
490 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
491 CHECK_OR_FALSE(mean_of_ifm);
492 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
493
494 // continue search from add_as_variance
495 CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
496 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
497 // TODO Support regarding broadcast
498 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
499
 // sqrt <- Mean <- Square <- Sub(ifm, Mean(ifm))
500 mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
501 CHECK_OR_FALSE(mean_as_variance);
502
503 square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
504 CHECK_OR_FALSE(square);
505
506 sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
507 CHECK_OR_FALSE(sub_2);
508 CHECK_OR_FALSE(ifm == sub_2->x());
509
510 mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
511 CHECK_OR_FALSE(mean_of_ifm_2);
512 CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
513
 // Re-derive both operands of sub_2 through the filler and compare with what was found above
514 loco::Node *ifm_should_be = nullptr;
515 luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
517 luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
518 CHECK_OR_FALSE(ifm == ifm_should_be);
519 CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
520
521 return true;
522}
523
// Matcher for Version_1, anchored at add_as_terminal. The statements below require
// (operands of Add/Mul in either order):
//   add_as_terminal    = mul_as_scaled_ifm + sub
//   mul_as_scaled_ifm  = ifm * mul_gamma
//   mul_gamma          = rsqrt * const_as_gamma
//   sub                = const_as_beta - mul_as_scaled_mean
//   mul_as_scaled_mean = mul_gamma * mean_of_ifm
// The chain behind rsqrt and the beta constant are verified by condition_common_1_5().
// Sets _matched and returns true only when every check passes; members filled before a
// failing check keep their values.
524template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
525{
526 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
527 CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
528
 // The input must have a valid rank-4 shape; axis 3 gives the channel depth
529 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
530 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
531 CHECK_OR_FALSE(ifm_circle->rank() == 4);
532 CHECK_OR_FALSE(ifm_circle->dim(3).known());
533 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
534
535 CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
536
 // gamma must be shaped [1, ..., 1, channel depth]
537 CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
538
539 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
540
541 luci::CircleMul *mul_gamma_should_be = nullptr;
542 luci::CircleMean *mean_of_ifm_should_be = nullptr;
543
 // The subtracted term must be the same mul_gamma times the same mean_of_ifm found above
544 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
545 CHECK_OR_FALSE(mul_as_scaled_mean);
546 CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
547 .with_commutative_args_of(mul_as_scaled_mean));
548 CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
549 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
550
551 _matched = true;
552 return true;
553}
554
// Matcher for Version_2, anchored at add_as_terminal. The statements below require
// (operands of Add/Mul/SquaredDifference in either order):
//   add_as_terminal  = mul_gamma + const_as_beta
//   mul_gamma        = div * const_as_gamma
//   div              = sub / pow
//   sub              = ifm - mean_of_ifm,            mean_of_ifm = Mean(ifm)
//   pow              = Pow(add_as_variance, 0.5)
//   add_as_variance  = mean_as_variance + const_as_epsilon
//   mean_as_variance = Mean(sqdiff),                 sqdiff = SquaredDifference(ifm, mean_of_ifm)
// gamma and beta must be rank-1 FLOAT32 constants with one value per channel (axis 3 of ifm).
//
// NOTE(review): the listing numbering jumps 588->590 and 601->603; the unbalanced ')' on the
// following lines suggests each missing line was a 'CHECK_OR_FALSE(' opener - confirm
// against the real file.
555template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
556{
557 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
558 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
559
 // Numerator of the Div
560 sub = dynamic_cast<luci::CircleSub *>(div->x());
561 CHECK_OR_FALSE(sub);
562
563 ifm = sub->x();
564 CHECK_OR_FALSE(ifm);
565
 // NOTE(review): unlike the Version_1 matcher, shape_status() is not checked before rank()
566 luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
567 CHECK_OR_FALSE(ifm_node->rank() == 4);
568 CHECK_OR_FALSE(ifm_node->dim(3).known());
569 uint32_t ifm_channel_depth = ifm_node->dim(3).value();
570
571 mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
572 CHECK_OR_FALSE(mean_of_ifm);
573
574 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
575
 // Denominator of the Div
576 pow = dynamic_cast<luci::CirclePow *>(div->y());
577 CHECK_OR_FALSE(pow);
578
579 add_as_variance = dynamic_cast<luci::CircleAdd *>(pow->x());
580 CHECK_OR_FALSE(add_as_variance);
581
 // The exponent must be the single FLOAT32 value 0.5
582 luci::CircleConst *zero_point_five = dynamic_cast<luci::CircleConst *>(pow->y());
583 CHECK_OR_FALSE(zero_point_five);
584 CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
585 // TODO Support regarding broadcast
586 CHECK_OR_FALSE(zero_point_five->size<loco::DataType::FLOAT32>() == 1);
587 CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
588
590 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
591 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
592 // TODO Support regarding broadcast
593 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
594
595 CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
596
597 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
598 CHECK_OR_FALSE(sqdiff);
599
 // sqdiff must combine the same ifm and the same mean_of_ifm that feed 'sub'
600 loco::Node *ifm_should_be = nullptr;
601 luci::CircleMean *mean_of_ifm_should_be = nullptr;
603 luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
604 CHECK_OR_FALSE(ifm == ifm_should_be);
605 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
606
607 // Check for channel size
608 CHECK_OR_FALSE(is_1D_float32_const(const_as_gamma, ifm_channel_depth));
609 CHECK_OR_FALSE(is_1D_float32_const(const_as_beta, ifm_channel_depth));
610
611 _matched = true;
612 return true;
613}
614
615template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
616{
617 CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
618 CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
619 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
620
621 CHECK_OR_FALSE(condition_common_3_4());
622
623 _matched = true;
624 return true;
625}
626
627luci::CircleConst *make_const_one(loco::Graph *graph, float value)
628{
629 auto const_one = graph->nodes()->create<luci::CircleConst>();
630 const_one->dtype(loco::DataType::FLOAT32);
631 const_one->rank(1);
632 const_one->size<loco::DataType::FLOAT32>(1);
633 const_one->at<loco::DataType::FLOAT32>(0) = value;
634 return const_one;
635}
636
637template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
638{
639 CHECK_OR_FALSE(div);
640 CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
641
642 CHECK_OR_FALSE(condition_common_3_4());
643
644 assert(const_as_gamma == nullptr);
645 assert(const_as_beta == nullptr);
646 assert(mul_gamma == nullptr);
647 assert(add_as_terminal == nullptr);
648
649 // create 1.0 gamma and 0.0 beta
650 auto graph = div->graph();
651 const_as_gamma = make_const_one(graph, 1.0f);
652 const_as_beta = make_const_one(graph, 0.0f);
653 const_as_gamma->name(div->name() + "/gamma");
654 const_as_beta->name(div->name() + "/beta");
655
656 _matched = true;
657 return true;
658}
659
// Matcher for Version_5, anchored at add_as_terminal. Same shape as Version_1 but without
// the gamma Mul; the statements below require (operands of Add/Mul in either order):
//   add_as_terminal    = mul_as_scaled_ifm + sub
//   mul_as_scaled_ifm  = ifm * rsqrt
//   sub                = const_as_beta - mul_as_scaled_mean
//   mul_as_scaled_mean = rsqrt * mean_of_ifm
// The chain behind rsqrt and the beta constant are verified by condition_common_1_5().
// On success a gamma constant of 1.0 is created in the graph of add_as_terminal.
660template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
661{
662 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
663 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
664
 // The input must have a valid rank-4 shape; axis 3 gives the channel depth
665 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
666 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
667 CHECK_OR_FALSE(ifm_circle->rank() == 4);
668 CHECK_OR_FALSE(ifm_circle->dim(3).known());
669 uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
670
671 CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
672
673 luci::CircleRsqrt *rsqrt_should_be = nullptr;
674 luci::CircleMean *mean_of_ifm_should_be = nullptr;
675
 // The subtracted term must be the same rsqrt times the same mean_of_ifm found above
676 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
677 CHECK_OR_FALSE(mul_as_scaled_mean);
678 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
679 .with_commutative_args_of(mul_as_scaled_mean));
680 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
681 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
682
683 // mul_gamma is absent
684 // const_as_gamma assume to be 1.0
685 auto graph = add_as_terminal->graph();
686 const_as_gamma = make_const_one(graph, 1.0f);
687 const_as_gamma->name(add_as_terminal->name() + "/gamma");
688
689 _matched = true;
690 return true;
691}
692
// Matcher for Version_6 (rank-3 input), anchored at add_as_terminal. The statements below
// require (operands of Add/Mul/SquaredDifference in either order):
//   add_as_terminal    = mul_as_scaled_ifm + sub
//   mul_as_scaled_ifm  = ifm * rsqrt
//   rsqrt              = Rsqrt(add_as_variance)
//   add_as_variance    = mean_as_variance + const_as_epsilon
//   mean_as_variance   = Mean(sqdiff),   sqdiff = SquaredDifference(ifm, mean_of_ifm)
//   sub                = const_as_beta - mul_as_scaled_mean
//   mul_as_scaled_mean = rsqrt * mean_of_ifm
// On success a gamma constant of 1.0 is created in the graph of add_as_terminal.
//
// NOTE(review): the listing numbering jumps 705->707, 720->722 and 729->731. The unbalanced
// ')' after 705 and 729 suggests 'CHECK_OR_FALSE(' openers; the content of the line after
// 720 cannot be told from this view - confirm against the real file before editing.
693template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_6>()
694{
695 CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
696 CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
697
 // The input must have a valid rank-3 shape with a known extent on axis 1
698 auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
699 CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
700 CHECK_OR_FALSE(ifm_circle->rank() == 3);
701 CHECK_OR_FALSE((ifm_circle->dim(1).known()));
702
703 add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
704 CHECK_OR_FALSE(add_as_variance);
705
707 luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
708
709 CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
710 // TODO Support regarding broadcast
711 CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
712
713 CHECK_OR_FALSE(is_instance_mean_v2(mean_as_variance));
714
715 sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
716 CHECK_OR_FALSE(sqdiff);
717
718 loco::Node *ifm_should_be = nullptr;
719 CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
720 CHECK_OR_FALSE(ifm == ifm_should_be);
722 CHECK_OR_FALSE(ifm == mean_of_ifm->input());
723
724 // If const_as_beta has shape of '1 x channel x (1 or input last dimension)'
 // NOTE(review): dim(2) is read without a known() check; only dim(1) was checked above
725 uint32_t input_channel = ifm_circle->dim(1).value();
726 uint32_t input_last_dim = ifm_circle->dim(2).value();
727 const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
728 CHECK_OR_FALSE(const_as_beta);
729 CHECK_OR_FALSE(const_as_beta->rank() == 3);
731 const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
732 (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));
733
734 luci::CircleRsqrt *rsqrt_should_be = nullptr;
735 luci::CircleMean *mean_of_ifm_should_be = nullptr;
736
 // The subtracted term must be the same rsqrt times the same mean_of_ifm found above
737 mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
738 CHECK_OR_FALSE(mul_as_scaled_mean);
739 CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
740 .with_commutative_args_of(mul_as_scaled_mean));
741 CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
742 CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
743
744 // mul_gamma is absent
745 // const_as_gamma assume to be 1.0
746 auto graph = add_as_terminal->graph();
747 const_as_gamma = make_const_one(graph, 1.0f);
748 const_as_gamma->name(add_as_terminal->name() + "/gamma");
749
750 _matched = true;
751 return true;
752}
753
754bool InstanceNormPattern::matched()
755{
756 if (_matched)
757 return true;
758
759 // Check order is DFS
760
761 switch (_pv)
762 {
763 case PatternVersion::Version_1:
764 return match<PatternVersion::Version_1>();
765 case PatternVersion::Version_2:
766 return match<PatternVersion::Version_2>();
767 case PatternVersion::Version_3:
768 return match<PatternVersion::Version_3>();
769 case PatternVersion::Version_4:
770 return match<PatternVersion::Version_4>();
771 case PatternVersion::Version_5:
772 return match<PatternVersion::Version_5>();
773 case PatternVersion::Version_6:
774 return match<PatternVersion::Version_6>();
775
776 default:
777 break;
778 }
779
780 throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
781}
782
783#undef CHECK_OR_FALSE
784
801class FuseInstanceNorm final
802{
803public:
804 FuseInstanceNorm(const InstanceNormPattern &p) : _p(p) {}
805
806public:
807 void apply(void);
808
809private:
810 template <InstanceNormPattern::PatternVersion> void apply(void);
811
812private:
813 void reshape_gamma_beta(void);
814 luci::CircleInstanceNorm *create_inst_norm(loco::Graph *graph);
815
816private:
817 const InstanceNormPattern &_p;
818};
819
820void FuseInstanceNorm::reshape_gamma_beta()
821{
822 // Version 1 and 3 need to reshape
823 {
824 _p.const_as_gamma->rank(1);
825 _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
826 _p.const_as_beta->rank(1);
827 _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
828
829 _p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
830 _p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
831 }
832}
833
834luci::CircleInstanceNorm *FuseInstanceNorm::create_inst_norm(loco::Graph *graph)
835{
836 // Make Instance Norm to replace
837 auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
838 instance_norm->input(_p.ifm);
839 instance_norm->gamma(_p.const_as_gamma);
840 instance_norm->beta(_p.const_as_beta);
841 float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
842 instance_norm->epsilon(epsilon);
843 if (_p.add_as_terminal != nullptr)
844 {
845 instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
846 // NOTE unique name should be assigned in export
847 instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
848 }
849 else
850 {
851 // VERSION_4
852 assert(_p.div != nullptr);
853 instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
854 instance_norm->name("FusedInstanceNorm/" + _p.div->name());
855 }
856
857 return instance_norm;
858}
859
860template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
861{
862 auto graph = _p.add_as_terminal->graph();
863
864 reshape_gamma_beta();
865
866 auto instance_norm = create_inst_norm(graph);
867
868 // set origin
869 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
870 luci::get_origin(_p.mean_of_ifm),
871 luci::get_origin(_p.sqdiff),
872 luci::get_origin(_p.mean_as_variance),
873 luci::get_origin(_p.add_as_variance),
874 luci::get_origin(_p.rsqrt),
875 luci::get_origin(_p.mul_gamma),
876 luci::get_origin(_p.mul_as_scaled_ifm),
877 luci::get_origin(_p.mul_as_scaled_mean),
878 luci::get_origin(_p.sub),
879 luci::get_origin(_p.add_as_terminal)};
880
881 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
882
883 replace(_p.add_as_terminal).with(instance_norm);
884}
885
886template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
887{
888 auto graph = _p.add_as_terminal->graph();
889
890 auto instance_norm = create_inst_norm(graph);
891
892 // set origin
893 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
894 luci::get_origin(_p.mean_of_ifm),
895 luci::get_origin(_p.sqdiff),
896 luci::get_origin(_p.mean_as_variance),
897 luci::get_origin(_p.add_as_variance),
898 luci::get_origin(_p.pow),
899 luci::get_origin(_p.sub),
900 luci::get_origin(_p.div),
901 luci::get_origin(_p.mul_gamma),
902 luci::get_origin(_p.add_as_terminal)};
903
904 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
905
906 replace(_p.add_as_terminal).with(instance_norm);
907}
908
909template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
910{
911 auto graph = _p.add_as_terminal->graph();
912
913 reshape_gamma_beta();
914
915 auto instance_norm = create_inst_norm(graph);
916
917 // set origin
918 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
919 luci::get_origin(_p.mean_of_ifm),
920 luci::get_origin(_p.sub),
921 luci::get_origin(_p.mean_of_ifm_2),
922 luci::get_origin(_p.sub_2),
923 luci::get_origin(_p.square),
924 luci::get_origin(_p.mean_as_variance),
925 luci::get_origin(_p.sqrt),
926 luci::get_origin(_p.add_as_variance),
927 luci::get_origin(_p.div),
928 luci::get_origin(_p.mul_gamma),
929 luci::get_origin(_p.add_as_terminal)};
930
931 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
932
933 replace(_p.add_as_terminal).with(instance_norm);
934}
935
936template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
937{
938 auto graph = _p.div->graph();
939
940 auto instance_norm = create_inst_norm(graph);
941
942 // set origin
943 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
944 luci::get_origin(_p.mean_of_ifm),
945 luci::get_origin(_p.sub),
946 luci::get_origin(_p.mean_of_ifm_2),
947 luci::get_origin(_p.sub_2),
948 luci::get_origin(_p.square),
949 luci::get_origin(_p.mean_as_variance),
950 luci::get_origin(_p.sqrt),
951 luci::get_origin(_p.add_as_variance),
952 luci::get_origin(_p.div)};
953
954 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
955
956 replace(_p.div).with(instance_norm);
957}
958
959template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
960{
961 auto graph = _p.add_as_terminal->graph();
962
963 reshape_gamma_beta();
964
965 auto instance_norm = create_inst_norm(graph);
966
967 // set origin
968 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
969 luci::get_origin(_p.mean_of_ifm),
970 luci::get_origin(_p.sqdiff),
971 luci::get_origin(_p.mean_as_variance),
972 luci::get_origin(_p.add_as_variance),
973 luci::get_origin(_p.rsqrt),
974 luci::get_origin(_p.mul_as_scaled_ifm),
975 luci::get_origin(_p.mul_as_scaled_mean),
976 luci::get_origin(_p.sub),
977 luci::get_origin(_p.add_as_terminal)};
978
979 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
980
981 replace(_p.add_as_terminal).with(instance_norm);
982}
983
984template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_6>()
985{
986 auto graph = _p.add_as_terminal->graph();
987
988 reshape_gamma_beta();
989
990 auto instance_norm = create_inst_norm(graph);
991
992 // set origin
993 std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
994 luci::get_origin(_p.mean_of_ifm),
995 luci::get_origin(_p.sqdiff),
996 luci::get_origin(_p.mean_as_variance),
997 luci::get_origin(_p.add_as_variance),
998 luci::get_origin(_p.rsqrt),
999 luci::get_origin(_p.mul_as_scaled_ifm),
1000 luci::get_origin(_p.mul_as_scaled_mean),
1001 luci::get_origin(_p.sub),
1002 luci::get_origin(_p.add_as_terminal)};
1003
1004 luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
1005
1006 replace(_p.add_as_terminal).with(instance_norm);
1007}
1008
1009void FuseInstanceNorm::apply()
1010{
1011 assert(_p.matched());
1012
1013 switch (_p.version())
1014 {
1015 case InstanceNormPattern::PatternVersion::Version_1:
1016 apply<InstanceNormPattern::PatternVersion::Version_1>();
1017 break;
1018 case InstanceNormPattern::PatternVersion::Version_2:
1019 apply<InstanceNormPattern::PatternVersion::Version_2>();
1020 break;
1021 case InstanceNormPattern::PatternVersion::Version_3:
1022 apply<InstanceNormPattern::PatternVersion::Version_3>();
1023 break;
1024 case InstanceNormPattern::PatternVersion::Version_4:
1025 apply<InstanceNormPattern::PatternVersion::Version_4>();
1026 break;
1027 case InstanceNormPattern::PatternVersion::Version_5:
1028 apply<InstanceNormPattern::PatternVersion::Version_5>();
1029 break;
1030 case InstanceNormPattern::PatternVersion::Version_6:
1031 apply<InstanceNormPattern::PatternVersion::Version_6>();
1032 break;
1033
1034 default:
1035 break;
1036 }
1037}
1038
1039} // namespace
1040
1041namespace
1042{
1043
1044class PostFusion final
1045{
1046public:
1047 PostFusion(luci::CircleInstanceNorm *inst_norm) : _inst_norm(inst_norm) {}
1048
1049private:
1050 uint32_t input_channel(void);
1051
1052 luci::CircleConst *match_const_channel(luci::CircleConst *, uint32_t);
1053 bool match_const_gamma_channel(void);
1054 bool match_const_beta_channel(void);
1055
1056public:
1057 bool process(void);
1058
1059private:
1060 luci::CircleInstanceNorm *_inst_norm = nullptr;
1061};
1062
1066uint32_t PostFusion::input_channel(void)
1067{
1068 auto input = dynamic_cast<luci::CircleNode *>(_inst_norm->input());
1069 if (input == nullptr)
1070 return 0;
1071 if (input->shape_status() != luci::ShapeStatus::VALID)
1072 return 0;
1073
1074 auto input_rank = input->rank();
1075 if (input_rank < 1)
1076 return 0;
1077
1078 if (input_rank == 3)
1079 {
1080 // use dim 1
1081 return input->dim(1).value();
1082 }
1083 // assume channel-last
1084 return input->dim(input_rank - 1).value();
1085}
1086
1090luci::CircleConst *PostFusion::match_const_channel(luci::CircleConst *input_const, uint32_t C)
1091{
1092 luci::CircleConst *new_input_const = nullptr;
1093
1094 auto input_chn = input_const->dim(0).value();
1095 if (input_chn == 1 && input_chn != C)
1096 {
1097 float value = input_const->at<loco::DataType::FLOAT32>(0);
1098 auto clone = luci::clone_node(input_const, input_const->graph());
1099
1100 new_input_const = loco::must_cast<luci::CircleConst *>(clone);
1101 new_input_const->rank(1);
1102 new_input_const->dim(0).set(C);
1103 new_input_const->size<loco::DataType::FLOAT32>(C);
1104 for (uint32_t c = 0; c < C; ++c)
1105 new_input_const->at<loco::DataType::FLOAT32>(c) = value;
1106 }
1107
1108 return new_input_const;
1109}
1110
1114bool PostFusion::match_const_gamma_channel(void)
1115{
1116 auto const_as_gamma = dynamic_cast<luci::CircleConst *>(_inst_norm->gamma());
1117 if (const_as_gamma == nullptr)
1118 return false;
1119
1120 auto C = input_channel();
1121 if (C == 0)
1122 return false;
1123
1124 auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
1125 if (new_const_as_gamma == nullptr)
1126 return false;
1127
1128 _inst_norm->gamma(new_const_as_gamma);
1129
1130 return true;
1131}
1132
1136bool PostFusion::match_const_beta_channel(void)
1137{
1138 auto const_as_beta = dynamic_cast<luci::CircleConst *>(_inst_norm->beta());
1139 if (const_as_beta == nullptr)
1140 return false;
1141
1142 auto C = input_channel();
1143 if (C == 0)
1144 return false;
1145
1146 auto new_const_as_beta = match_const_channel(const_as_beta, C);
1147 if (new_const_as_beta == nullptr)
1148 return false;
1149
1150 _inst_norm->beta(new_const_as_beta);
1151
1152 return true;
1153}
1154
1155bool PostFusion::process(void)
1156{
1157 bool changed = false;
1158
1159 if (match_const_gamma_channel())
1160 changed = true;
1161 if (match_const_beta_channel())
1162 changed = true;
1163
1164 return changed;
1165}
1166
1167} // namespace
1168
1169namespace
1170{
1171
1172bool is_add_input_mul_const(luci::CircleAdd *add)
1173{
1174 luci::CircleMul *p_mul = nullptr;
1175 luci::CircleConst *p_const = nullptr;
1176
1177 return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
1178}
1179
1180bool is_add_input_mul_sub3d(luci::CircleAdd *add)
1181{
1182 luci::CircleMul *p_mul = nullptr;
1183 luci::CircleSub *p_sub = nullptr;
1184
1185 if (!luci::fill(&p_mul, &p_sub).with_commutative_args_of(add))
1186 return false;
1187
1188 auto sub = dynamic_cast<luci::CircleSub *>(add->y());
1189 if (sub == nullptr)
1190 return false;
1191
1192 auto const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
1193 if (const_as_beta == nullptr || const_as_beta->rank() != 3)
1194 return false;
1195
1196 return true;
1197}
1198
1199bool fuse_instance_norm(luci::CircleAdd *add)
1200{
1201 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
1202
1203 if (is_add_input_mul_const(add))
1204 pv = InstanceNormPattern::PatternVersion::Version_2;
1205 else if (is_add_input_mul_sub3d(add))
1206 pv = InstanceNormPattern::PatternVersion::Version_6;
1207
1208 InstanceNormPattern pattern(add, pv);
1209 if (pattern.matched())
1210 {
1211 FuseInstanceNorm fuse(pattern);
1212 fuse.apply();
1213 return true;
1214 }
1215
1216 if (pv == InstanceNormPattern::PatternVersion::Version_1)
1217 {
1218 // if Version_1 failed, try with Version_5
1219 pv = InstanceNormPattern::PatternVersion::Version_5;
1220 InstanceNormPattern pattern(add, pv);
1221 if (pattern.matched())
1222 {
1223 FuseInstanceNorm fuse(pattern);
1224 fuse.apply();
1225 return true;
1226 }
1227 }
1228 else if (pv == InstanceNormPattern::PatternVersion::Version_2)
1229 {
1230 // if Version_2 failed, try with Version_3
1231 pv = InstanceNormPattern::PatternVersion::Version_3;
1232 InstanceNormPattern pattern(add, pv);
1233 if (pattern.matched())
1234 {
1235 FuseInstanceNorm fuse(pattern);
1236 fuse.apply();
1237 return true;
1238 }
1239 }
1240
1241 return false;
1242}
1243
1244bool fuse_instance_norm(luci::CircleDiv *div)
1245{
1246 InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
1247
1248 InstanceNormPattern pattern(div, pv);
1249 if (pattern.matched())
1250 {
1251 FuseInstanceNorm fuse(pattern);
1252 fuse.apply();
1253 return true;
1254 }
1255
1256 return false;
1257}
1258
1259bool post_fusion(luci::CircleInstanceNorm *inst_norm)
1260{
1261 PostFusion postfusion(inst_norm);
1262
1263 return postfusion.process();
1264}
1265
1266} // namespace
1267
1268namespace luci
1269{
1270
// Body of the pass entry point. It makes three sweeps over the nodes reachable from the
// graph outputs and returns true when any sweep changed the graph:
//   1) every Add  -> fuse_instance_norm(add)
//   2) every Div  -> fuse_instance_norm(div)
//   3) every InstanceNorm -> post_fusion(inst_norm)
// The active node set is recomputed for each sweep.
//
// NOTE(review): the listing numbering jumps 1270->1272, so the line carrying the function
// signature is missing from this view. It is presumably the pass's
// 'bool run(loco::Graph *g)' definition - confirm against the real file.
1272{
1273 bool changed = false;
1274
1275 // Check Version_1, Version_2, Version_3, Version_5, Version_6
1276 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1277 {
1278 auto add = dynamic_cast<luci::CircleAdd *>(node);
1279 if (not add)
1280 continue;
1281
1282 if (fuse_instance_norm(add))
1283 changed = true;
1284 }
1285
1286 // Check Version_4(from DIV) if MUL-ADD pattern is not found
1287 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1288 {
1289 auto div = dynamic_cast<luci::CircleDiv *>(node);
1290 if (not div)
1291 continue;
1292
1293 if (fuse_instance_norm(div))
1294 changed = true;
1295 }
1296
1297 // Post processing of FuseInstanceNorm
 // This sweep visits every InstanceNorm node that is reachable, not only those created above
1298 for (auto node : loco::active_nodes(loco::output_nodes(g)))
1299 {
1300 auto inst_norm = dynamic_cast<luci::CircleInstanceNorm *>(node);
1301 if (not inst_norm)
1302 continue;
1303
1304 if (post_fusion(inst_norm))
1305 changed = true;
1306 }
1307
1308 return changed;
1309}
1310
1311} // namespace luci
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
Graph * graph(void)
Definition Node.h:70
void with(Node *into) const
Definition Node.cpp:66
ADD in Circle.
Definition CircleAdd.h:34
Class to build tensor data.
Definition CircleConst.h:35
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
uint32_t size(void) const
DIV in Circle.
Definition CircleDiv.h:37
INSTANCE_NORM in Circle.
loco::Node * input(void) const
MEAN in Circle.
Definition CircleMean.h:32
bool keep_dims(void) const
Definition CircleMean.h:41
loco::Node * input(void) const
Definition CircleMean.h:34
loco::Node * reduction_indices(void) const
Definition CircleMean.h:37
MUL in Circle.
Definition CircleMul.h:34
POW in Circle.
Definition CirclePow.h:32
RESHAPE in Circle.
RSQRT in Circle.
Definition CircleRsqrt.h:32
SQRT in Circle.
Definition CircleSqrt.h:32
SQUARE in Circle.
SQUARED_DIFFERENCE in Circle.
SUB in Circle.
Definition CircleSub.h:34
#define CHECK_OR_FALSE(condition)
bool is_instance_mean_v1(luci::CircleMean *mean)
bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
bool is_instance_mean_v2(luci::CircleMean *mean)
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
C
Definition infer.py:52
ShapeInferenceSession apply(ShapeInferenceRule *r)
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
std::shared_ptr< CircleNodeOrigin > composite_origin(const std::initializer_list< std::shared_ptr< CircleNodeOrigin > > origins)
NodeFiller< ARG_TYPE_1, ARG_TYPE_2 > fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
Definition NodeFiller.h:72
CircleNode * clone_node(const CircleNode *node, loco::Graph *graph)
Return a new cloned CircleNode object with same attributes value of node to graph.
luci::CircleConst * clone(luci::CircleConst *node)
Return cloned object of CircleConst node.
T square(T value)
Definition Loss.h:37
bool run(loco::Graph *g) final
Run the pass.