59void calculateGRU(
const float *input_data,
const float *weight_input_data,
60 const float *weight_hidden_data,
const float *bias_input_data,
61 const float *bias_hidden_data,
float *output_data,
62 const tflite::RuntimeShape &input_shape,
const tflite::RuntimeShape &
output_shape,
63 const tflite::RuntimeShape &weight_input_shape,
64 const tflite::RuntimeShape &weight_hidden_shape,
float *output_input_data,
65 float *output_hidden_data,
const tflite::RuntimeShape &output_shape_fc)
67 tflite::FullyConnectedParams op_params{};
69 op_params.float_activation_min = std::numeric_limits<float>::lowest();
70 op_params.float_activation_max = std::numeric_limits<float>::max();
73 tflite::RuntimeShape bias_input_shape{weight_input_shape.Dims(0)};
74 tflite::reference_ops::FullyConnected(op_params,
output_shape, output_data, weight_input_shape,
75 weight_input_data, bias_input_shape, bias_input_data,
76 output_shape_fc, output_input_data);
79 tflite::RuntimeShape bias_hidden_shape{weight_hidden_shape.Dims(0)};
81 tflite::reference_ops::FullyConnected(op_params, input_shape, input_data, weight_hidden_shape,
82 weight_hidden_data, bias_hidden_shape, bias_hidden_data,
83 output_shape_fc, output_hidden_data);
85 int num_elements = output_shape_fc.Dims(1) / 3;
87 float *second_hidden_part = output_hidden_data + num_elements;
88 float *second_input_part = output_input_data + num_elements;
90 float *third_hidden_part = second_hidden_part + num_elements;
91 float *third_input_part = second_input_part + num_elements;
94 for (
int i = 0; i < num_elements; ++i)
96 output_input_data[i] += output_hidden_data[i];
99 Logistic(num_elements, output_input_data, output_input_data);
102 float *most_left_part_final = output_input_data;
103 float *first_part = output_input_data;
104 for (
int i = 0; i < num_elements; ++i)
106 output_data[i] *= most_left_part_final[i];
107 first_part[i] = 1.0f - first_part[i];
111 for (
int i = 0; i < num_elements; ++i)
113 second_hidden_part[i] += second_input_part[i];
116 Logistic(num_elements, second_hidden_part, second_hidden_part);
118 for (
int i = 0; i < num_elements; ++i)
120 second_hidden_part[i] *= third_input_part[i];
121 second_hidden_part[i] += third_hidden_part[i];
124 for (
int i = 0; i < num_elements; ++i)
126 if (second_hidden_part[i] > 19)
128 second_hidden_part[i] = 1;
130 else if (second_hidden_part[i] < -19)
132 second_hidden_part[i] = -1;
136 second_hidden_part[i] = std::tanh(second_hidden_part[i]);
140 for (
int i = 0; i < num_elements; ++i)
142 second_hidden_part[i] *= first_part[i];
143 output_data[i] += second_hidden_part[i];
147void GRU(
const float *input_data,
const float *weight_input_data,
const float *weight_hidden_data,
148 const float *bias_input_data,
const float *bias_hidden_data,
149 const float *hidden_state_data,
float *output_data,
float *output_input_data,
150 float *output_hidden_data,
const tflite::RuntimeShape &input_shape,
151 const tflite::RuntimeShape &
output_shape,
const tflite::RuntimeShape &weight_input_shape,
152 const tflite::RuntimeShape &weight_hidden_shape)
154 const int32_t time = input_shape.Dims(0);
156 tflite::RuntimeShape output_shape_fc(2);
157 output_shape_fc.SetDim(0, 1);
158 output_shape_fc.SetDim(1, weight_hidden_shape.Dims(0));
160 std::memcpy(output_data, hidden_state_data,
output_shape.FlatSize() *
sizeof(
float));
162 for (
int i = 0; i < time; ++i)
164 calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data,
165 bias_hidden_data, output_data, input_shape,
output_shape, weight_input_shape,
166 weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc);
167 input_data += input_shape.Dims(2);
// Redeclarations of the GRU kernels defined above. Each prototype must be terminated with ';'
// to be a well-formed declaration; redeclaring an already-defined function is legal C++.
void GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data,
         const float *bias_input_data, const float *bias_hidden_data,
         const float *hidden_state_data, float *output_data, float *output_input_data,
         float *output_hidden_data, const tflite::RuntimeShape &input_shape,
         const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape,
         const tflite::RuntimeShape &weight_hidden_shape);
void calculateGRU(const float *input_data, const float *weight_input_data,
                  const float *weight_hidden_data, const float *bias_input_data,
                  const float *bias_hidden_data, float *output_data,
                  const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape,
                  const tflite::RuntimeShape &weight_input_shape,
                  const tflite::RuntimeShape &weight_hidden_shape, float *output_input_data,
                  float *output_hidden_data, const tflite::RuntimeShape &output_shape_fc);