60 const circle::Tensor *input;
61 const circle::Tensor *hidden_hidden;
62 const circle::Tensor *hidden_hidden_bias;
63 const circle::Tensor *hidden_input;
64 const circle::Tensor *hidden_input_bias;
65 const circle::Tensor *state;
67 const circle::Tensor *output;
70 uint8_t *hidden_hidden_data;
71 uint8_t *hidden_hidden_bias_data;
72 uint8_t *hidden_input_data;
73 uint8_t *hidden_input_bias_data;
77 uint16_t state_tensor_index = 0;
82 runtime_kernel.
readKernel(op_index, runtime_context);
84 input = runtime_kernel.
inputs[inputTensorIdx];
85 hidden_hidden = runtime_kernel.
inputs[hiddenHiddenTensorIdx];
86 hidden_hidden_bias = runtime_kernel.
inputs[hiddenHiddenBiasTensorIdx];
87 hidden_input = runtime_kernel.
inputs[hiddenInputTensorIdx];
88 hidden_input_bias = runtime_kernel.
inputs[hiddenInputBiasTensorIdx];
89 state = runtime_kernel.
inputs[stateTensorIdx];
91 output = runtime_kernel.
outputs[outputTensorIdx];
92 assert(input !=
nullptr);
93 assert(hidden_hidden !=
nullptr);
94 assert(hidden_input !=
nullptr);
95 assert(state !=
nullptr);
97 assert(output !=
nullptr);
101 input_data = runtime_kernel.
inputs_data[inputTensorIdx];
102 hidden_hidden_data = runtime_kernel.
inputs_data[hiddenHiddenTensorIdx];
103 hidden_hidden_bias_data = runtime_kernel.
inputs_data[hiddenHiddenBiasTensorIdx];
104 hidden_input_data = runtime_kernel.
inputs_data[hiddenInputTensorIdx];
105 hidden_input_bias_data = runtime_kernel.
inputs_data[hiddenInputBiasTensorIdx];
106 state_data = runtime_kernel.
inputs_data[stateTensorIdx];
108 output_data = runtime_kernel.
outputs_data[outputTensorIdx];
109 assert(input_data !=
nullptr);
110 assert(hidden_hidden_data !=
nullptr);
111 assert(hidden_input_data !=
nullptr);
112 assert(state_data !=
nullptr);
114 assert(output_data !=
nullptr);
116 state_tensor_index = runtime_kernel.
inputs_index[stateTensorIdx];
121 uint8_t *output_hidden_data;
122 uint8_t *output_input_data;
127 &output_hidden_data);
138 const int32_t num_of_intermediate_tensors = 9;
141 assert(size_of_intermediate_tensors > 0);
142 if (size_of_intermediate_tensors == 0)
146 const int32_t output_size = size_of_intermediate_tensors;
155 size_t intermediate_buffer_size = 0;
156 uint8_t *intermediate_buffer =
nullptr;
161 uint32_t num_train_layers =
163 uint32_t last_node_pos = std::min(num_operators, num_train_layers);
164 uint32_t last_train_op_index = num_operators - last_node_pos;
168 intermediate_buffer_size = num_of_intermediate_tensors * size_of_intermediate_tensors;
171 time * intermediate_buffer_size * data_type_size, &intermediate_buffer);
180 switch (input->type())
183 case circle::TensorType_FLOAT32:
186 pal::GRU(core::utils::castInputData<float>(input_data),
187 core::utils::castInputData<float>(hidden_input_data),
188 core::utils::castInputData<float>(hidden_hidden_data),
189 core::utils::castInputData<float>(hidden_input_bias_data),
190 core::utils::castInputData<float>(hidden_hidden_bias_data),
191 core::utils::castInputData<float>(state_data),
192 core::utils::castOutputData<float>(output_data),
193 core::utils::castOutputData<float>(output_input_data),
194 core::utils::castOutputData<float>(output_hidden_data),
197 intermediate_buffer_size, core::utils::castOutputData<float>(intermediate_buffer));
204 assert(
false &&
"Unsupported type.");
OMStatus GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data, const float *bias_input_data, const float *bias_hidden_data, const float *hidden_state_data, float *output_data, float *output_input_data, float *output_hidden_data, const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape, const core::OMRuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size, float *intermediate_buffer)