623{
624
625
626 const auto &operands = _graph.
operands();
628 if (operands.at(output_index).info().isDynamic())
629 return;
630
631 const auto scratch_buffer_index{
633 const auto output_state_out_index{
635 const auto cell_state_out_index{
637
639 const auto input_to_input_weights_index{
641 const auto input_to_forget_weights_index{
643 const auto input_to_cell_weights_index{
645 const auto input_to_output_weights_index{
647 const auto recurrent_to_input_weights_index{
649 const auto recurrent_to_forget_weights_index{
651 const auto recurrent_to_cell_weights_index{
653 const auto recurrent_to_output_weights_index{
655 const auto cell_to_input_weights_index{
657 const auto cell_to_forget_weights_index{
659 const auto cell_to_output_weights_index{
661 const auto input_gate_bias_index{
663 const auto forget_gate_bias_index{
666 const auto output_gate_bias_index{
668 const auto projection_weights_index{
670 const auto projection_bias_index{
672 const auto output_state_in_index{
675
676 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
677 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
678 {
679 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
680 operands.at(output_index).shape().dim(i));
681 }
682 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
683 operands.at(output_index).shape().rank() == 3) &&
684 (operands.at(input_index).shape().rank() == 2 ||
685 operands.at(input_index).shape().rank() == 3) &&
686 (!operands.exist(input_to_input_weights_index) ||
687 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
688 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
689 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
690 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
691 (!operands.exist(recurrent_to_input_weights_index) ||
692 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
693 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
694 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
695 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
696 (!operands.exist(projection_weights_index) ||
697 operands.at(projection_weights_index).shape().rank() == 2) &&
698 operands.at(output_state_in_index).shape().rank() == 2 &&
699 operands.at(cell_state_in_index).shape().rank() == 2);
700
701 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
702 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
703 (!operands.exist(cell_to_forget_weights_index) ||
704 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
705 (!operands.exist(cell_to_output_weights_index) ||
706 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
707 (!operands.exist(input_gate_bias_index) ||
708 operands.at(input_gate_bias_index).shape().rank() == 1) &&
709 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
710 operands.at(cell_bias_index).shape().rank() == 1 &&
711 operands.at(output_gate_bias_index).shape().rank() == 1 &&
712 (!operands.exist(projection_bias_index) ||
713 operands.at(projection_bias_index).shape().rank() == 1));
714
715
716 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
717 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
718 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
719 (!operands.exist(recurrent_to_input_weights_index) ||
720 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
721 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
722 (!operands.exist(input_gate_bias_index) ||
723 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
724 (!operands.exist(cell_to_input_weights_index) ||
725 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
726 ((operands.exist(input_to_input_weights_index) &&
727 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
728 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
729 (operands.exist(recurrent_to_input_weights_index) &&
730 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
731 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
732 (operands.exist(input_gate_bias_index) &&
733 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
734
735
736 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
737 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
738 (!operands.exist(cell_to_output_weights_index) ||
739 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
740 ((operands.exist(cell_to_forget_weights_index) &&
741 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
742 (operands.exist(cell_to_output_weights_index) &&
743 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
744
745 bool has_input_to_input_weights =
746 operands.exist(input_to_input_weights_index) &&
747 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
748 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
749 bool has_recurrent_to_input_weights =
750 operands.exist(recurrent_to_input_weights_index) &&
751 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
752 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
753 bool has_input_gate_bias =
754 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
755 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
756 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
757 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
758 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
759 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
760 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
761 bool has_projection_weights = operands.exist(projection_weights_index) &&
762 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
763 operands.at(projection_weights_index).shape().dim(1) != 0);
764 bool has_projection_bias =
765 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
766
767
768
769
770 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
771
772
773
774
775 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
776
777
778 bool has_projection_param = has_projection_weights;
779
780 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
781 ? operands.at(input_index).shape().dim(1)
782 : operands.at(input_index).shape().dim(0);
783 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
784 batch_size == operands.at(cell_state_in_index).shape().dim(0));
785
786 const auto input_size =
787 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
788 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
789 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
790 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
791
792 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
793 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
794 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
795 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
796 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
797 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
798 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
799 num_units == operands.at(cell_bias_index).shape().dim(0) &&
800 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
801 num_units == operands.at(cell_state_in_index).shape().dim(1));
802
803 const auto output_size =
804 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
805 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
806 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
807 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
808 output_size == operands.at(output_state_in_index).shape().dim(1));
809
810 if (has_cifg_param)
811 {
812 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
814 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
815 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
816 ((operands.exist(cell_to_input_weights_index) &&
817 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
818 (!operands.exist(cell_to_input_weights_index) ||
819 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) ) &&
820 num_units == operands.at(input_gate_bias_index).shape().dim(0));
821 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
822 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
823 has_input_gate_bias);
824 if (has_cell_to_input_weights)
825 {
826
828 }
829 if (operands.exist(scratch_buffer_index))
830 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
831 }
832 else
833 {
834 if (operands.exist(scratch_buffer_index))
835 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
836 }
837
838 if (has_peephole_param)
839 {
840 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
841 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
842 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
843 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 ));
844 }
845
846 if (has_projection_param)
847 {
848 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
849 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
850 if (has_projection_bias)
851 {
852 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
853 }
854 }
855
856 if (operands.exist(scratch_buffer_index))
857 {
858 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
859 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
860 }
861
862 if (operands.exist(output_state_out_index))
863 {
864 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
865 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
866 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
867 }
868
869 if (operands.exist(cell_state_out_index))
870 {
871 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
872 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
873 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
874 }
875}
@ RECURRENT_TO_CELL_WEIGHTS
@ RECURRENT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_INPUT_WEIGHTS
@ INPUT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_OUTPUT_WEIGHTS
@ INPUT_TO_OUTPUT_WEIGHTS