488{
489
490
491 const auto &operands = _graph.
operands();
493 if (operands.at(output_index).info().isDynamic())
494 return;
495
496 const auto scratch_buffer_index{
498 const auto output_state_out_index{
500 const auto cell_state_out_index{
502
504 const auto input_to_input_weights_index{
506 const auto input_to_forget_weights_index{
508 const auto input_to_cell_weights_index{
510 const auto input_to_output_weights_index{
512 const auto recurrent_to_input_weights_index{
514 const auto recurrent_to_forget_weights_index{
516 const auto recurrent_to_cell_weights_index{
518 const auto recurrent_to_output_weights_index{
520 const auto cell_to_input_weights_index{
522 const auto cell_to_forget_weights_index{
524 const auto cell_to_output_weights_index{
526 const auto input_gate_bias_index{
528 const auto forget_gate_bias_index{
531 const auto output_gate_bias_index{
533 const auto projection_weights_index{
535 const auto projection_bias_index{
537 const auto output_state_in_index{
540
541 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
542 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
543 {
544 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
545 operands.at(output_index).shape().dim(i));
546 }
547 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
548 operands.at(output_index).shape().rank() == 3) &&
549 (operands.at(input_index).shape().rank() == 2 ||
550 operands.at(input_index).shape().rank() == 3) &&
551 (!operands.exist(input_to_input_weights_index) ||
552 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
553 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
554 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
555 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
556 (!operands.exist(recurrent_to_input_weights_index) ||
557 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
558 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
559 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
560 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
561 (!operands.exist(projection_weights_index) ||
562 operands.at(projection_weights_index).shape().rank() == 2) &&
563 operands.at(output_state_in_index).shape().rank() == 2 &&
564 operands.at(cell_state_in_index).shape().rank() == 2);
565
566 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
567 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
568 (!operands.exist(cell_to_forget_weights_index) ||
569 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
570 (!operands.exist(cell_to_output_weights_index) ||
571 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
572 (!operands.exist(input_gate_bias_index) ||
573 operands.at(input_gate_bias_index).shape().rank() == 1) &&
574 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
575 operands.at(cell_bias_index).shape().rank() == 1 &&
576 operands.at(output_gate_bias_index).shape().rank() == 1 &&
577 (!operands.exist(projection_bias_index) ||
578 operands.at(projection_bias_index).shape().rank() == 1));
579
580
581 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
582 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
583 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
584 (!operands.exist(recurrent_to_input_weights_index) ||
585 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
586 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
587 (!operands.exist(input_gate_bias_index) ||
588 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
589 (!operands.exist(cell_to_input_weights_index) ||
590 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
591 ((operands.exist(input_to_input_weights_index) &&
592 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
593 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
594 (operands.exist(recurrent_to_input_weights_index) &&
595 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
596 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
597 (operands.exist(input_gate_bias_index) &&
598 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
599
600
601 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
602 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
603 (!operands.exist(cell_to_output_weights_index) ||
604 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
605 ((operands.exist(cell_to_forget_weights_index) &&
606 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
607 (operands.exist(cell_to_output_weights_index) &&
608 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
609
610 bool has_input_to_input_weights =
611 operands.exist(input_to_input_weights_index) &&
612 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
613 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
614 bool has_recurrent_to_input_weights =
615 operands.exist(recurrent_to_input_weights_index) &&
616 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
617 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
618 bool has_input_gate_bias =
619 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
620 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
621 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
622 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
623 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
624 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
625 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
626 bool has_projection_weights = operands.exist(projection_weights_index) &&
627 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
628 operands.at(projection_weights_index).shape().dim(1) != 0);
629 bool has_projection_bias =
630 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
631
632
633
634
635 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
636
637
638
639
640 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
641
642
643 bool has_projection_param = has_projection_weights;
644
645 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
646 ? operands.at(input_index).shape().dim(1)
647 : operands.at(input_index).shape().dim(0);
648 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
649 batch_size == operands.at(cell_state_in_index).shape().dim(0));
650
651 const auto input_size =
652 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
653 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
654 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
655 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
656
657 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
658 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
659 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
660 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
661 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
662 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
663 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
664 num_units == operands.at(cell_bias_index).shape().dim(0) &&
665 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
666 num_units == operands.at(cell_state_in_index).shape().dim(1));
667
668 const auto output_size =
669 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
670 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
671 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
672 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
673 output_size == operands.at(output_state_in_index).shape().dim(1));
674
675 if (has_cifg_param)
676 {
677 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
679 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
680 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
681 ((operands.exist(cell_to_input_weights_index) &&
682 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
683 (!operands.exist(cell_to_input_weights_index) ||
684 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) ) &&
685 num_units == operands.at(input_gate_bias_index).shape().dim(0));
686 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
687 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
688 has_input_gate_bias);
689 if (has_cell_to_input_weights)
690 {
691
693 }
694 if (operands.exist(scratch_buffer_index))
695 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
696 }
697 else
698 {
699 if (operands.exist(scratch_buffer_index))
700 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
701 }
702
703 if (has_peephole_param)
704 {
705 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
706 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
707 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
708 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 ));
709 }
710
711 if (has_projection_param)
712 {
713 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
714 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
715 if (has_projection_bias)
716 {
717 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
718 }
719 }
720
721 if (operands.exist(scratch_buffer_index))
722 {
723 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
724 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
725 }
726
727 if (operands.exist(output_state_out_index))
728 {
729 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
730 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
731 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
732 }
733
734 if (operands.exist(cell_state_out_index))
735 {
736 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
737 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
738 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
739 }
740}
@ RECURRENT_TO_CELL_WEIGHTS
@ RECURRENT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_INPUT_WEIGHTS
@ INPUT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_OUTPUT_WEIGHTS
@ INPUT_TO_OUTPUT_WEIGHTS