// Positions of the GRU kernel's operands within runtime_kernel.inputs /
// runtime_kernel.inputs_data. The same positions are used both when reading
// forward-pass data and when reading the corresponding gradient buffers.
constexpr uint32_t inputTensorIdx{0};            // activation input
constexpr uint32_t hiddenHiddenTensorIdx{1};     // hidden->hidden weights
constexpr uint32_t hiddenHiddenBiasTensorIdx{2}; // hidden->hidden bias
constexpr uint32_t hiddenInputTensorIdx{3};      // hidden->input weights
constexpr uint32_t hiddenInputBiasTensorIdx{4};  // hidden->input bias
constexpr uint32_t stateTensorIdx{5};            // state / intermediate buffer slot
51 uint16_t op_index =
args.kernel_index;
54 runtime_kernel.
readKernel(op_index, runtime_context);
56 const circle::Tensor *
input = runtime_kernel.
inputs[inputTensorIdx];
57 const circle::Tensor *weight_input = runtime_kernel.
inputs[hiddenInputTensorIdx];
58 const circle::Tensor *weight_hidden = runtime_kernel.
inputs[hiddenHiddenTensorIdx];
61 assert(input !=
nullptr);
62 assert(output !=
nullptr);
67 status = runtime_kernel.
getDataFromStorage(op_index, forward_storage, runtime_context);
71 uint8_t *weight_input_data = runtime_kernel.
inputs_data[hiddenInputTensorIdx];
72 uint8_t *weight_hidden_data = runtime_kernel.
inputs_data[hiddenHiddenTensorIdx];
73 uint8_t *bias_input_data = runtime_kernel.
inputs_data[hiddenInputBiasTensorIdx];
74 uint8_t *bias_hidden_data = runtime_kernel.
inputs_data[hiddenHiddenBiasTensorIdx];
75 uint8_t *intermediate_buffer = runtime_kernel.
inputs_data[stateTensorIdx];
77 assert(input_data !=
nullptr);
78 assert(weight_input_data !=
nullptr);
79 assert(weight_hidden_data !=
nullptr);
80 assert(intermediate_buffer !=
nullptr);
83 status = runtime_kernel.
getDataFromStorage(op_index, backward_storage, runtime_context);
85 uint8_t *weight_input_grad_data = runtime_kernel.
inputs_data[hiddenInputTensorIdx];
86 uint8_t *weight_hidden_grad_data = runtime_kernel.
inputs_data[hiddenHiddenTensorIdx];
87 uint8_t *bias_input_grad_data = runtime_kernel.
inputs_data[hiddenInputBiasTensorIdx];
88 uint8_t *bias_hidden_grad_data = runtime_kernel.
inputs_data[hiddenHiddenBiasTensorIdx];
89 uint8_t *state_grad_data = runtime_kernel.
inputs_data[stateTensorIdx];
90 uint8_t *input_grad_data = runtime_kernel.
inputs_data[inputTensorIdx];
93 assert(output_grad_data !=
nullptr);
94 assert(weight_input_grad_data !=
nullptr);
95 assert(weight_hidden_grad_data !=
nullptr);
96 assert(state_grad_data !=
nullptr);
106 output_shape_fc.setDim(0, 1);
107 output_shape_fc.setDim(1, weight_hidden_shape.dims(0));
110 uint8_t *left_fc_output_grad_buffer;
111 uint8_t *right_fc_output_grad_buffer;
113 assert(weight_hidden_shape.dims(0) == weight_input_shape.dims(0));
124 assert(left_fc_output_grad_buffer !=
nullptr and right_fc_output_grad_buffer !=
nullptr);
127 if (
input->type() != circle::TensorType_FLOAT32)
132 core::utils::castInputData<float>(weight_input_data),
133 core::utils::castOutputData<float>(weight_input_grad_data),
134 core::utils::castInputData<float>(weight_hidden_data),
135 core::utils::castOutputData<float>(weight_hidden_grad_data),
136 core::utils::castInputData<float>(bias_input_data),
137 core::utils::castOutputData<float>(bias_input_grad_data),
138 core::utils::castInputData<float>(bias_hidden_data),
139 core::utils::castOutputData<float>(bias_hidden_grad_data),
140 core::utils::castInputData<float>(input_data),
141 core::utils::castOutputData<float>(input_grad_data),
142 core::utils::castOutputData<float>(state_grad_data), input_shape,
143 output_shape, weight_input_shape, weight_hidden_shape, output_shape_fc,
144 core::utils::castOutputData<float>(intermediate_buffer),
145 core::utils::castOutputData<float>(left_fc_output_grad_buffer),
146 core::utils::castOutputData<float>(right_fc_output_grad_buffer));
OMStatus removeTensorFromTensorIndexToData(uint16_t tensor_index)
uint8_t * outputs_data[maxOutputSize]
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
uint8_t * inputs_data[maxInputSize]
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
int32_t inputs_index[maxInputSize]
const circle::Tensor * inputs[maxInputSize]
const luci_interpreter::RuntimeShape output_shape
constexpr uint32_t outputTensorIdx
OMDataType
"scalar" value type
OMStatus GRUWeightGrads(const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data, const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data, float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data, const float *input_data, float *input_grad_data, float *state_grad_data, const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape, const core::OMRuntimeShape &weight_hidden_shape, const core::OMRuntimeShape &output_shape_fc, float *intermediate_buffer, float *left_fc_output_grad_buffer, float *right_fc_output_grad_buffer)
static OMStatus deallocateMemory(uint8_t *data)
static OMStatus allocateMemory(uint32_t size, uint8_t **data)