456{
457
458
459 const auto &operands = _graph.
operands();
461 if (operands.at(output_index).info().isDynamic())
462 return;
463
464 const auto scratch_buffer_index{
466 const auto output_state_out_index{
468 const auto cell_state_out_index{
470
472 const auto input_to_input_weights_index{
474 const auto input_to_forget_weights_index{
476 const auto input_to_cell_weights_index{
478 const auto input_to_output_weights_index{
480 const auto recurrent_to_input_weights_index{
482 const auto recurrent_to_forget_weights_index{
484 const auto recurrent_to_cell_weights_index{
486 const auto recurrent_to_output_weights_index{
488 const auto cell_to_input_weights_index{
490 const auto cell_to_forget_weights_index{
492 const auto cell_to_output_weights_index{
494 const auto input_gate_bias_index{
496 const auto forget_gate_bias_index{
499 const auto output_gate_bias_index{
501 const auto projection_weights_index{
503 const auto projection_bias_index{
505 const auto output_state_in_index{
508
509 OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
510 for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
511 {
512 OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
513 operands.at(output_index).shape().dim(i));
514 }
515 OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
516 operands.at(output_index).shape().rank() == 3) &&
517 (operands.at(input_index).shape().rank() == 2 ||
518 operands.at(input_index).shape().rank() == 3) &&
519 (!operands.exist(input_to_input_weights_index) ||
520 operands.at(input_to_input_weights_index).shape().rank() == 2) &&
521 operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
522 operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
523 operands.at(input_to_output_weights_index).shape().rank() == 2 &&
524 (!operands.exist(recurrent_to_input_weights_index) ||
525 operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
526 operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
527 operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
528 operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
529 (!operands.exist(projection_weights_index) ||
530 operands.at(projection_weights_index).shape().rank() == 2) &&
531 operands.at(output_state_in_index).shape().rank() == 2 &&
532 operands.at(cell_state_in_index).shape().rank() == 2);
533
534 OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
535 operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
536 (!operands.exist(cell_to_forget_weights_index) ||
537 operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
538 (!operands.exist(cell_to_output_weights_index) ||
539 operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
540 (!operands.exist(input_gate_bias_index) ||
541 operands.at(input_gate_bias_index).shape().rank() == 1) &&
542 operands.at(forget_gate_bias_index).shape().rank() == 1 &&
543 operands.at(cell_bias_index).shape().rank() == 1 &&
544 operands.at(output_gate_bias_index).shape().rank() == 1 &&
545 (!operands.exist(projection_bias_index) ||
546 operands.at(projection_bias_index).shape().rank() == 1));
547
548
549 OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
550 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
551 operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
552 (!operands.exist(recurrent_to_input_weights_index) ||
553 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
554 operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
555 (!operands.exist(input_gate_bias_index) ||
556 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
557 (!operands.exist(cell_to_input_weights_index) ||
558 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
559 ((operands.exist(input_to_input_weights_index) &&
560 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
561 operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
562 (operands.exist(recurrent_to_input_weights_index) &&
563 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
564 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
565 (operands.exist(input_gate_bias_index) &&
566 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
567
568
569 OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
570 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
571 (!operands.exist(cell_to_output_weights_index) ||
572 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
573 ((operands.exist(cell_to_forget_weights_index) &&
574 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
575 (operands.exist(cell_to_output_weights_index) &&
576 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
577
578 bool has_input_to_input_weights =
579 operands.exist(input_to_input_weights_index) &&
580 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
581 operands.at(input_to_input_weights_index).shape().dim(1) != 0);
582 bool has_recurrent_to_input_weights =
583 operands.exist(recurrent_to_input_weights_index) &&
584 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
585 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
586 bool has_input_gate_bias =
587 operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
588 bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
589 operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
590 bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
591 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
592 bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
593 operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
594 bool has_projection_weights = operands.exist(projection_weights_index) &&
595 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
596 operands.at(projection_weights_index).shape().dim(1) != 0);
597 bool has_projection_bias =
598 operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
599
600
601
602
603 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
604
605
606
607
608 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
609
610
611 bool has_projection_param = has_projection_weights;
612
613 const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
614 ? operands.at(input_index).shape().dim(1)
615 : operands.at(input_index).shape().dim(0);
616 OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
617 batch_size == operands.at(cell_state_in_index).shape().dim(0));
618
619 const auto input_size =
620 operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
621 OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
622 input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
623 input_size == operands.at(input_to_output_weights_index).shape().dim(1));
624
625 const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
626 OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
627 num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
628 num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
629 num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
630 num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
631 num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
632 num_units == operands.at(cell_bias_index).shape().dim(0) &&
633 num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
634 num_units == operands.at(cell_state_in_index).shape().dim(1));
635
636 const auto output_size =
637 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
638 OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
639 output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
640 output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
641 output_size == operands.at(output_state_in_index).shape().dim(1));
642
643 if (has_cifg_param)
644 {
645 OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
647 num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
648 num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
649 ((operands.exist(cell_to_input_weights_index) &&
650 num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
651 (!operands.exist(cell_to_input_weights_index) ||
652 operands.at(cell_to_input_weights_index).shape().dim(0) == 0) ) &&
653 num_units == operands.at(input_gate_bias_index).shape().dim(0));
654 OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
655 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
656 has_input_gate_bias);
657 if (has_cell_to_input_weights)
658 {
659
661 }
662 if (operands.exist(scratch_buffer_index))
663 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
664 }
665 else
666 {
667 if (operands.exist(scratch_buffer_index))
668 OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
669 }
670
671 if (has_peephole_param)
672 {
673 OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
674 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
675 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
676 operands.at(cell_to_input_weights_index).shape().dim(0) == 0 ));
677 }
678
679 if (has_projection_param)
680 {
681 OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
682 OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
683 if (has_projection_bias)
684 {
685 OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
686 }
687 }
688
689 if (operands.exist(scratch_buffer_index))
690 {
691 OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
692 OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
693 }
694
695 if (operands.exist(output_state_out_index))
696 {
697 OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
698 OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
699 OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
700 }
701
702 if (operands.exist(cell_state_out_index))
703 {
704 OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
705 OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
706 OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
707 }
708}
@ RECURRENT_TO_CELL_WEIGHTS
@ RECURRENT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_INPUT_WEIGHTS
@ INPUT_TO_FORGET_WEIGHTS
@ RECURRENT_TO_OUTPUT_WEIGHTS
@ INPUT_TO_OUTPUT_WEIGHTS