ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::UnidirectionalSequenceLSTM Class Reference

#include <UnidirectionalSequenceLSTM.h>

Collaboration diagram for luci_interpreter::kernels::UnidirectionalSequenceLSTM:

Public Member Functions

 UnidirectionalSequenceLSTM (const Tensor *input, const Tensor *input_to_input_weights, const Tensor *input_to_forget_weights, const Tensor *input_to_cell_weights, const Tensor *input_to_output_weights, const Tensor *recurrent_to_input_weights, const Tensor *recurrent_to_forget_weights, const Tensor *recurrent_to_cell_weights, const Tensor *recurrent_to_output_weights, const Tensor *cell_to_input_weights, const Tensor *cell_to_forget_weights, const Tensor *cell_to_output_weights, const Tensor *input_gate_bias, const Tensor *forget_gate_bias, const Tensor *cell_gate_bias, const Tensor *output_gate_bias, const Tensor *projection_weights, const Tensor *projection_bias, const Tensor *output_state, const Tensor *cell_state, const Tensor *input_layer_norm_coefficients, const Tensor *forget_layer_norm_coefficients, const Tensor *cell_layer_norm_coefficients, const Tensor *output_layer_norm_coefficients, Tensor *output, Tensor *scratchpad_1, Tensor *scratchpad_2, Tensor *scratchpad_3, const UnidirectionalSequenceLSTMParams &params)
 
const Tensor * input () const
 
const Tensor * input_to_input_weights () const
 
const Tensor * input_to_forget_weights () const
 
const Tensor * input_to_cell_weights () const
 
const Tensor * input_to_output_weights () const
 
const Tensor * recurrent_to_input_weights () const
 
const Tensor * recurrent_to_forget_weights () const
 
const Tensor * recurrent_to_cell_weights () const
 
const Tensor * recurrent_to_output_weights () const
 
const Tensor * cell_to_input_weights () const
 
const Tensor * cell_to_forget_weights () const
 
const Tensor * cell_to_output_weights () const
 
const Tensor * input_gate_bias () const
 
const Tensor * forget_gate_bias () const
 
const Tensor * cell_gate_bias () const
 
const Tensor * output_gate_bias () const
 
const Tensor * projection_weights () const
 
const Tensor * projection_bias () const
 
const Tensor * output_state () const
 
const Tensor * cell_state () const
 
const Tensor * input_layer_norm_coefficients () const
 
const Tensor * forget_layer_norm_coefficients () const
 
const Tensor * cell_layer_norm_coefficients () const
 
const Tensor * output_layer_norm_coefficients () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< UnidirectionalSequenceLSTMParams >
const UnidirectionalSequenceLSTMParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< UnidirectionalSequenceLSTMParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const UnidirectionalSequenceLSTMParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< UnidirectionalSequenceLSTMParams >
const UnidirectionalSequenceLSTMParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 29 of file UnidirectionalSequenceLSTM.h.

Constructor & Destructor Documentation

◆ UnidirectionalSequenceLSTM()

luci_interpreter::kernels::UnidirectionalSequenceLSTM::UnidirectionalSequenceLSTM ( const Tensor * input,
const Tensor * input_to_input_weights,
const Tensor * input_to_forget_weights,
const Tensor * input_to_cell_weights,
const Tensor * input_to_output_weights,
const Tensor * recurrent_to_input_weights,
const Tensor * recurrent_to_forget_weights,
const Tensor * recurrent_to_cell_weights,
const Tensor * recurrent_to_output_weights,
const Tensor * cell_to_input_weights,
const Tensor * cell_to_forget_weights,
const Tensor * cell_to_output_weights,
const Tensor * input_gate_bias,
const Tensor * forget_gate_bias,
const Tensor * cell_gate_bias,
const Tensor * output_gate_bias,
const Tensor * projection_weights,
const Tensor * projection_bias,
const Tensor * output_state,
const Tensor * cell_state,
const Tensor * input_layer_norm_coefficients,
const Tensor * forget_layer_norm_coefficients,
const Tensor * cell_layer_norm_coefficients,
const Tensor * output_layer_norm_coefficients,
Tensor * output,
Tensor * scratchpad_1,
Tensor * scratchpad_2,
Tensor * scratchpad_3,
const UnidirectionalSequenceLSTMParams & params
)

Definition at line 440 of file UnidirectionalSequenceLSTM.cpp.

463 : KernelWithParams<UnidirectionalSequenceLSTMParams>(
464 {input,
469
474
478
483
486
489
494 {output, scratchpad_1, scratchpad_2, scratchpad_3}, params)
495{
496 // Do nothing
497}
const UnidirectionalSequenceLSTMParams & params() const
Definition Kernel.h:67

References cell_gate_bias(), cell_layer_norm_coefficients(), cell_state(), cell_to_forget_weights(), cell_to_input_weights(), cell_to_output_weights(), forget_gate_bias(), forget_layer_norm_coefficients(), input(), input_gate_bias(), input_layer_norm_coefficients(), input_to_cell_weights(), input_to_forget_weights(), input_to_input_weights(), input_to_output_weights(), output_gate_bias(), output_layer_norm_coefficients(), output_state(), projection_bias(), projection_weights(), recurrent_to_cell_weights(), recurrent_to_forget_weights(), recurrent_to_input_weights(), and recurrent_to_output_weights().

Member Function Documentation

◆ cell_gate_bias()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::cell_gate_bias ( ) const
inline

Definition at line 75 of file UnidirectionalSequenceLSTM.h.

75{ return _inputs[14]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ cell_layer_norm_coefficients()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::cell_layer_norm_coefficients ( ) const
inline

Definition at line 86 of file UnidirectionalSequenceLSTM.h.

86{ return _inputs[22]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ cell_state()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::cell_state ( ) const
inline

Definition at line 82 of file UnidirectionalSequenceLSTM.h.

82{ return _inputs[19]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and UnidirectionalSequenceLSTM().

◆ cell_to_forget_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::cell_to_forget_weights ( ) const
inline

Definition at line 70 of file UnidirectionalSequenceLSTM.h.

70{ return _inputs[10]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ cell_to_input_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::cell_to_input_weights ( ) const
inline

Definition at line 69 of file UnidirectionalSequenceLSTM.h.

69{ return _inputs[9]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ cell_to_output_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::cell_to_output_weights ( ) const
inline

Definition at line 71 of file UnidirectionalSequenceLSTM.h.

71{ return _inputs[11]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ configure()

void luci_interpreter::kernels::UnidirectionalSequenceLSTM::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 751 of file UnidirectionalSequenceLSTM.cpp.

752{
755
756 // TODO support U8
757 LUCI_INTERPRETER_CHECK(input()->element_type() == loco::DataType::FLOAT32);
758 const bool is_integer = false;
759 const bool use_layer_norm = (forget_layer_norm_coefficients() != nullptr);
760
761 // Inferring batch size, number of outputs and sequence length and
762 // number of cells from the input tensors.
763 const Shape &input_shape = input()->shape();
764 LUCI_INTERPRETER_CHECK(input_shape.num_dims() > 1);
765 const bool time_major = params().time_major;
766 const int n_batch = time_major ? input_shape.dim(1) : input_shape.dim(0);
767 // NOTE as dim(2) is accessed, we need to check this is valid
768 LUCI_INTERPRETER_CHECK(input_shape.num_dims() > 2);
769 const int n_input = input_shape.dim(2);
770
771 const Shape &input_to_output_weights_shape = input_to_output_weights()->shape();
772 const int n_cell = input_to_output_weights_shape.dim(0);
773 LUCI_INTERPRETER_CHECK(input_to_output_weights_shape.num_dims() == 2);
774 LUCI_INTERPRETER_CHECK(input_to_output_weights_shape.dim(1) == n_input);
775
776 const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights()->shape();
777 LUCI_INTERPRETER_CHECK(recurrent_to_output_weights_shape.num_dims() == 2);
778 LUCI_INTERPRETER_CHECK(recurrent_to_output_weights_shape.dim(0) == n_cell);
779
780 const int n_output = recurrent_to_output_weights_shape.dim(1);
781
782 // Check that input tensor dimensions matches with each other.
783 check_input_tensor_dimensions(n_input, n_output, n_cell, use_layer_norm, is_integer);
784
785 // Check the shape of input state tensors.
786 // These tensor may be 1D or 2D. It's fine as long as the total size is
787 // correct.
788 const Shape &output_state_shape = output_state()->shape();
789 const Shape &cell_state_shape = cell_state()->shape();
790 LUCI_INTERPRETER_CHECK(output_state_shape.num_elements() == n_batch * n_output);
791 LUCI_INTERPRETER_CHECK(cell_state_shape.num_elements() == n_batch * n_cell);
792
793 // Resize the output tensors.
794 Shape output_shape = Shape(input_shape.num_dims());
795 for (int i = 0; i < input_shape.num_dims() - 1; i++)
796 {
797 output_shape.dim(i) = input_shape.dim(i);
798 }
799 output_shape.dim(input_shape.num_dims() - 1) = n_output;
801
802 // TODO import integer
803
804 // output_state and cell_state are variable tensor; use scratchpad.
805 getOutputTensors()[1]->resize(output_state_shape);
806 getOutputTensors()[2]->resize(cell_state_shape);
807
808 const bool use_cifg = (input_to_input_weights() == nullptr);
809 if (use_cifg)
810 getOutputTensors()[3]->resize({n_batch, n_cell * 3});
811 else
812 getOutputTensors()[3]->resize({n_batch, n_cell * 4});
813
814 // hybrid not supported
815 if (input_to_output_weights()->element_type() == loco::DataType::U8 &&
816 input()->element_type() == loco::DataType::FLOAT32)
817 {
818 throw std::runtime_error("Hybrid type is not currently supported");
819 }
820 // TODO support hybrid
821 // TODO support U8
822}
const std::vector< Tensor * > & getOutputTensors() const
Definition Kernel.h:40
const std::vector< const Tensor * > & getInputTensors() const
Definition Kernel.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
int32_t size[5]
Definition Slice.cpp:35
Definition Shape.h:28

References cell_state(), luci_interpreter::Shape::dim(), forget_layer_norm_coefficients(), luci_interpreter::Kernel::getInputTensors(), luci_interpreter::Kernel::getOutputTensors(), input(), input_to_input_weights(), input_to_output_weights(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), luci_interpreter::Shape::num_elements(), output(), output_shape, output_state(), luci_interpreter::KernelWithParams< UnidirectionalSequenceLSTMParams >::params(), recurrent_to_output_weights(), luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::shape(), size, and luci_interpreter::UnidirectionalSequenceLSTMParams::time_major.

◆ execute()

void luci_interpreter::kernels::UnidirectionalSequenceLSTM::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 824 of file UnidirectionalSequenceLSTM.cpp.

825{
826 switch (input()->element_type())
827 {
828 case loco::DataType::FLOAT32:
829 evalFloat();
830 break;
831 default:
832 throw std::runtime_error("Unsupported type");
833 }
834}

References input().

◆ forget_gate_bias()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::forget_gate_bias ( ) const
inline

Definition at line 74 of file UnidirectionalSequenceLSTM.h.

74{ return _inputs[13]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ forget_layer_norm_coefficients()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::forget_layer_norm_coefficients ( ) const
inline

Definition at line 85 of file UnidirectionalSequenceLSTM.h.

85{ return _inputs[21]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and UnidirectionalSequenceLSTM().

◆ input()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::input ( ) const
inline

Definition at line 57 of file UnidirectionalSequenceLSTM.h.

57{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and UnidirectionalSequenceLSTM().

◆ input_gate_bias()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::input_gate_bias ( ) const
inline

Definition at line 73 of file UnidirectionalSequenceLSTM.h.

73{ return _inputs[12]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ input_layer_norm_coefficients()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::input_layer_norm_coefficients ( ) const
inline

Definition at line 84 of file UnidirectionalSequenceLSTM.h.

84{ return _inputs[20]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ input_to_cell_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::input_to_cell_weights ( ) const
inline

Definition at line 61 of file UnidirectionalSequenceLSTM.h.

61{ return _inputs[3]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ input_to_forget_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::input_to_forget_weights ( ) const
inline

Definition at line 60 of file UnidirectionalSequenceLSTM.h.

60{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ input_to_input_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::input_to_input_weights ( ) const
inline

Definition at line 59 of file UnidirectionalSequenceLSTM.h.

59{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and UnidirectionalSequenceLSTM().

◆ input_to_output_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::input_to_output_weights ( ) const
inline

Definition at line 62 of file UnidirectionalSequenceLSTM.h.

62{ return _inputs[4]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and UnidirectionalSequenceLSTM().

◆ output()

Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::output ( ) const
inline

Definition at line 89 of file UnidirectionalSequenceLSTM.h.

89{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().

◆ output_gate_bias()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::output_gate_bias ( ) const
inline

Definition at line 76 of file UnidirectionalSequenceLSTM.h.

76{ return _inputs[15]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ output_layer_norm_coefficients()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::output_layer_norm_coefficients ( ) const
inline

Definition at line 87 of file UnidirectionalSequenceLSTM.h.

87{ return _inputs[23]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ output_state()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::output_state ( ) const
inline

Definition at line 81 of file UnidirectionalSequenceLSTM.h.

81{ return _inputs[18]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and UnidirectionalSequenceLSTM().

◆ projection_bias()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::projection_bias ( ) const
inline

Definition at line 79 of file UnidirectionalSequenceLSTM.h.

79{ return _inputs[17]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ projection_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::projection_weights ( ) const
inline

Definition at line 78 of file UnidirectionalSequenceLSTM.h.

78{ return _inputs[16]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ recurrent_to_cell_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::recurrent_to_cell_weights ( ) const
inline

Definition at line 66 of file UnidirectionalSequenceLSTM.h.

66{ return _inputs[7]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ recurrent_to_forget_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::recurrent_to_forget_weights ( ) const
inline

Definition at line 65 of file UnidirectionalSequenceLSTM.h.

65{ return _inputs[6]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ recurrent_to_input_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::recurrent_to_input_weights ( ) const
inline

Definition at line 64 of file UnidirectionalSequenceLSTM.h.

64{ return _inputs[5]; }

References luci_interpreter::Kernel::_inputs.

Referenced by UnidirectionalSequenceLSTM().

◆ recurrent_to_output_weights()

const Tensor * luci_interpreter::kernels::UnidirectionalSequenceLSTM::recurrent_to_output_weights ( ) const
inline

Definition at line 67 of file UnidirectionalSequenceLSTM.h.

67{ return _inputs[8]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and UnidirectionalSequenceLSTM().


The documentation for this class was generated from the following files: