// Slot numbers of the GRU operator's input tensors.
// The kernel code below uses these constants to index runtime_kernel.inputs
// (tensor descriptors), runtime_kernel.inputs_data (raw data pointers) and
// runtime_kernel.inputs_index (tensor indices) after readKernel() is called.

// Slot of the main input tensor.
// NOTE(review): the statement that reads this slot is not visible in this
// excerpt; presumably it feeds `input` / `input_data` - confirm.
37constexpr uint32_t inputTensorIdx = 0;
// Slot of the hidden-hidden tensor; its data is forwarded to pal::GRU in the
// weight_hidden_data position.
38constexpr uint32_t hiddenHiddenTensorIdx = 1;
// Slot of the hidden-hidden bias; its data is forwarded to pal::GRU in the
// bias_hidden_data position.
// NOTE(review): no null assert for this tensor is visible - presumably the
// bias is optional; confirm.
39constexpr uint32_t hiddenHiddenBiasTensorIdx = 2;
// Slot of the hidden-input tensor; its data is forwarded to pal::GRU in the
// weight_input_data position.
40constexpr uint32_t hiddenInputTensorIdx = 3;
// Slot of the hidden-input bias; its data is forwarded to pal::GRU in the
// bias_input_data position.
// NOTE(review): no null assert for this tensor is visible - presumably the
// bias is optional; confirm.
41constexpr uint32_t hiddenInputBiasTensorIdx = 4;
// Slot of the state tensor; its data is forwarded to pal::GRU in the
// hidden_state_data position, and its tensor index is kept in
// state_tensor_index.
42constexpr uint32_t stateTensorIdx = 5;
55 const circle::Tensor *
input;
56 const circle::Tensor *hidden_hidden;
57 const circle::Tensor *hidden_hidden_bias;
58 const circle::Tensor *hidden_input;
59 const circle::Tensor *hidden_input_bias;
60 const circle::Tensor *state;
62 const circle::Tensor *
output;
65 uint8_t *hidden_hidden_data;
66 uint8_t *hidden_hidden_bias_data;
67 uint8_t *hidden_input_data;
68 uint8_t *hidden_input_bias_data;
72 uint16_t state_tensor_index = 0;
77 runtime_kernel.
readKernel(op_index, runtime_context);
80 hidden_hidden = runtime_kernel.
inputs[hiddenHiddenTensorIdx];
81 hidden_hidden_bias = runtime_kernel.
inputs[hiddenHiddenBiasTensorIdx];
82 hidden_input = runtime_kernel.
inputs[hiddenInputTensorIdx];
83 hidden_input_bias = runtime_kernel.
inputs[hiddenInputBiasTensorIdx];
84 state = runtime_kernel.
inputs[stateTensorIdx];
87 assert(input !=
nullptr);
88 assert(hidden_hidden !=
nullptr);
89 assert(hidden_input !=
nullptr);
90 assert(state !=
nullptr);
92 assert(output !=
nullptr);
97 hidden_hidden_data = runtime_kernel.
inputs_data[hiddenHiddenTensorIdx];
98 hidden_hidden_bias_data = runtime_kernel.
inputs_data[hiddenHiddenBiasTensorIdx];
99 hidden_input_data = runtime_kernel.
inputs_data[hiddenInputTensorIdx];
100 hidden_input_bias_data = runtime_kernel.
inputs_data[hiddenInputBiasTensorIdx];
101 state_data = runtime_kernel.
inputs_data[stateTensorIdx];
104 assert(input_data !=
nullptr);
105 assert(hidden_hidden_data !=
nullptr);
106 assert(hidden_input_data !=
nullptr);
107 assert(state_data !=
nullptr);
109 assert(output_data !=
nullptr);
111 state_tensor_index = runtime_kernel.
inputs_index[stateTensorIdx];
116 uint8_t *output_hidden_data;
117 uint8_t *output_input_data;
122 &output_hidden_data);
133 const int32_t num_of_intermediate_tensors = 9;
136 assert(size_of_intermediate_tensors > 0);
137 if (size_of_intermediate_tensors == 0)
141 const int32_t output_size = size_of_intermediate_tensors;
150 size_t intermediate_buffer_size = 0;
151 uint8_t *intermediate_buffer =
nullptr;
156 uint32_t num_train_layers =
158 uint32_t last_node_pos = std::min(num_operators, num_train_layers);
159 uint32_t last_train_op_index = num_operators - last_node_pos;
163 intermediate_buffer_size = num_of_intermediate_tensors * size_of_intermediate_tensors;
166 time * intermediate_buffer_size * data_type_size, &intermediate_buffer);
175 switch (
input->type())
178 case circle::TensorType_FLOAT32:
181 pal::GRU(core::utils::castInputData<float>(input_data),
182 core::utils::castInputData<float>(hidden_input_data),
183 core::utils::castInputData<float>(hidden_hidden_data),
184 core::utils::castInputData<float>(hidden_input_bias_data),
185 core::utils::castInputData<float>(hidden_hidden_bias_data),
186 core::utils::castInputData<float>(state_data),
187 core::utils::castOutputData<float>(output_data),
188 core::utils::castOutputData<float>(output_input_data),
189 core::utils::castOutputData<float>(output_hidden_data),
192 intermediate_buffer_size, core::utils::castOutputData<float>(intermediate_buffer));
199 assert(
false &&
"Unsupported type.");
const reader::CircleOperators * getCircleOperators()
int32_t dims(int i) const
OMStatus saveDataToTensorIndex(uint8_t *data, uint16_t tensor_index)
uint8_t * outputs_data[maxOutputSize]
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
uint8_t * inputs_data[maxInputSize]
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
int32_t inputs_index[maxInputSize]
const circle::Tensor * inputs[maxInputSize]
constexpr uint32_t outputTensorIdx
OMDataType
"scalar" value type
OMStatus GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data, const float *bias_input_data, const float *bias_hidden_data, const float *hidden_state_data, float *output_data, float *output_input_data, float *output_hidden_data, const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape, const core::OMRuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size, float *intermediate_buffer)
static OMStatus deallocateMemory(uint8_t *data)
static OMStatus allocateMemory(uint32_t size, uint8_t **data)
uint32_t num_train_layers
core::OMRuntimeContext & runtime_context
core::OMRuntimeStorage & runtime_storage