620{
621
622
623 const auto &operands = _graph.
operands();
625 if (operands.at(output_index).info().isDynamic())
626 return;
627
628 const auto scratch_buffer_index{
630 const auto output_state_out_index{
632 const auto cell_state_out_index{
634
636 const auto input_to_input_weights_index{
638 const auto input_to_forget_weights_index{
640 const auto input_to_cell_weights_index{
642 const auto input_to_output_weights_index{
644 const auto recurrent_to_input_weights_index{
646 const auto recurrent_to_forget_weights_index{
648 const auto recurrent_to_cell_weights_index{
650 const auto recurrent_to_output_weights_index{
652 const auto cell_to_input_weights_index{
654 const auto cell_to_forget_weights_index{
656 const auto cell_to_output_weights_index{
658 const auto input_gate_bias_index{
660 const auto forget_gate_bias_index{
663 const auto output_gate_bias_index{
665 const auto projection_weights_index{
667 const auto projection_bias_index{
669 const auto output_state_in_index{
672
673 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
674 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
675 {
676 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
677 operands.at(output_index).shape().dim(i));
678 }
679 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
680 operands.at(output_index).shape().rank() == 3) &&
681 (operands.at(input_index).shape().rank() == 2 ||
682 operands.at(input_index).shape().rank() == 3) &&
683 (!operands.exist(input_to_input_weights_index) ||
684 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
685 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
686 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
687 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
688 (!operands.exist(recurrent_to_input_weights_index) ||
689 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
690 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
691 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
692 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
693 (!operands.exist(projection_weights_index) ||
694 operands.at(projection_weights_index).shape().rank() == 2) &&
695 operands.at(output_state_in_index).shape().rank() == 2 &&
696 operands.at(cell_state_in_index).shape().rank() == 2);
697
698 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
699 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
700 (!operands.exist(cell_to_forget_weights_index) ||
701 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
702 (!operands.exist(cell_to_output_weights_index) ||
703 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
704 (!operands.exist(input_gate_bias_index) ||
705 operands.at(input_gate_bias_index).shape().rank() == 1) &&
706 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
707 operands.at(cell_bias_index).shape().rank() == 1 &&
708 operands.at(output_gate_bias_index).shape().rank() == 1 &&
709 (!operands.exist(projection_bias_index) ||
710 operands.at(projection_bias_index).shape().rank() == 1));
711
712
713 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
714 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
715 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
716 (!operands.exist(recurrent_to_input_weights_index) ||
717 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
718 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
719 (!operands.exist(input_gate_bias_index) ||
720 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
721 (!operands.exist(cell_to_input_weights_index) ||
722 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
723 ((operands.exist(input_to_input_weights_index) &&
724 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
725 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
726 (operands.exist(recurrent_to_input_weights_index) &&
727 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
728 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
729 (operands.exist(input_gate_bias_index) &&
730 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
731
732
733 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
734 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
735 (!operands.exist(cell_to_output_weights_index) ||
736 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
737 ((operands.exist(cell_to_forget_weights_index) &&
738 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
739 (operands.exist(cell_to_output_weights_index) &&
740 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
741
742 bool has_input_to_input_weights =
743 operands.exist(input_to_input_weights_index) &&
744 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
745 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
746 bool has_recurrent_to_input_weights =
747 operands.exist(recurrent_to_input_weights_index) &&
748 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
749 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
750 bool has_input_gate_bias =
751 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
752 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
753 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
754 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
755 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
756 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
757 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
758 bool has_projection_weights = operands.exist(projection_weights_index) &&
759 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
760 operands.at(projection_weights_index).shape().dim(1) != 0);
761 bool has_projection_bias =
762 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
763
764
765
766
767 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
768
769
770
771
772 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
773
774
775 bool has_projection_param = has_projection_weights;
776
777 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
778 ? operands.at(input_index).shape().dim(1)
779 : operands.at(input_index).shape().dim(0);
780 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
781 batch_size == operands.at(cell_state_in_index).shape().dim(0));
782
783 const auto input_size =
784 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
785 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
786 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
787 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
788
789 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
790 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
791 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
792 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
793 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
794 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
795 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
796 num_units == operands.at(cell_bias_index).shape().dim(0) &&
797 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
798 num_units == operands.at(cell_state_in_index).shape().dim(1));
799
800 const auto output_size =
801 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
802 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
803 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
804 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
805 output_size == operands.at(output_state_in_index).shape().dim(1));
806
807 if (has_cifg_param)
808 {
809 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
811 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
812 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
813 ((operands.exist(cell_to_input_weights_index) &&
814 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
815 (!operands.exist(cell_to_input_weights_index) ||
816 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) ) &&
817 num_units == operands.at(input_gate_bias_index).shape().dim(0));
818 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
819 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
820 has_input_gate_bias);
821 if (has_cell_to_input_weights)
822 {
823
825 }
826 if (operands.exist(scratch_buffer_index))
827 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
828 }
829 else
830 {
831 if (operands.exist(scratch_buffer_index))
832 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
833 }
834
835 if (has_peephole_param)
836 {
837 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
838 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
839 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
840 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 ));
841 }
842
843 if (has_projection_param)
844 {
845 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
846 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
847 if (has_projection_bias)
848 {
849 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
850 }
851 }
852
853 if (operands.exist(scratch_buffer_index))
854 {
855 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
856 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
857 }
858
859 if (operands.exist(output_state_out_index))
860 {
861 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
862 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
863 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
864 }
865
866 if (operands.exist(cell_state_out_index))
867 {
868 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
869 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
870 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
871 }
872}
@ RECURRENT_TO_CELL_WEIGHTS
@ RECURRENT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_INPUT_WEIGHTS
@ INPUT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_OUTPUT_WEIGHTS
@ INPUT_TO_OUTPUT_WEIGHTS