void UpdateLstmCellFloat(int n_batch, int n_cell, float *cell_state, const float *input_gate,
                         float *forget_gate, const float *cell_gate, bool use_cifg, float clip)
{
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state, n_batch * n_cell, cell_state);

  if (use_cifg)
  {
    // With CIFG, input_gate = 1 - forget_gate. Use the forget_gate array as scratch, as the
    // input_gate array is not allocated in this case.
    float *scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, scratch, n_batch * n_cell,
                                                     cell_state);
  }
  else
  {
    tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, input_gate, n_batch * n_cell,
                                                     cell_state);
  }
  if (clip > 0.0f)
  {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
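
// In element-wise terms, UpdateLstmCellFloat above computes
//   cell_state = forget_gate .* cell_state + input_gate .* cell_gate
// where input_gate is replaced by (1 - forget_gate) when CIFG is used, and the result is clipped
// to [-clip, clip] when clip > 0.
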
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, const float *cell_state,
                              const float *output_gate, TfLiteFusedActivation activation,
                              const float *projection_weights, const float *projection_bias,
                              const float proj_clip, float *output_state, float *scratch)
{
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell, scratch);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection)
  {
    if (use_projection_bias)
    {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch, output_state);
    }
    else
    {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(projection_weights, n_output, n_cell, scratch,
                                                      n_batch, output_state);
    if (proj_clip > 0.0f)
    {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  }
  else
  {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}
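
// In element-wise terms, CalculateLstmOutputFloat above computes
//   output_state = output_gate .* activation(cell_state)
// and, when projection weights are present,
//   output_state = clip(projection_weights * output_state + projection_bias, proj_clip)
// with the clipping applied only when proj_clip > 0.
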
// Calculates a single LSTM gate.
inline void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights,
                                   const float *aux_input, const float *aux_input_to_gate_weights,
                                   const float *output_state,
                                   const float *recurrent_to_gate_weights, const float *cell_state,
                                   const float *cell_to_gate_weights,
                                   const float *layer_norm_coefficients, const float *gate_bias,
                                   const int n_batch, const int n_input, const int n_aux_input,
                                   const int n_output, const int n_cell,
                                   const TfLiteFusedActivation activation, float *gate,
                                   const bool is_input_all_zeros,
                                   const bool is_aux_input_all_zeros)
{
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize the gate scratch with the bias for a regular LSTM, or with zero for a layer-norm
  // LSTM (the bias is added after normalization in that case).
  if (use_layer_norm)
  {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  }
  else
  {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input. Skip if the input is all zeros.
  if (!is_input_all_zeros)
  {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, n_cell, n_input, input,
                                                      n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if the auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros)
  {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, n_cell,
                                                      n_aux_input, aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, n_cell, n_output,
                                                    output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM).
  if (use_peephole)
  {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_gate_weights, n_cell, cell_state,
                                                          n_batch, gate);
  }
  // Do layer normalization (if layer-norm LSTM).
  if (use_layer_norm)
  {
    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, gate, n_batch,
                                                gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply the activation.
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation, gate);
}
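
// In element-wise terms, CalculateLstmGateFloat above computes
//   gate = activation(input_to_gate_weights * input + aux_input_to_gate_weights * aux_input +
//                     recurrent_to_gate_weights * output_state +
//                     cell_to_gate_weights .* cell_state + gate_bias)
// where, with layer normalization, the accumulated sum is normalized and scaled by
// layer_norm_coefficients before the bias is added. Skipping all-zero inputs is purely an
// optimization and does not change the result.
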
inline void LstmStepFloat(
  const float *input_ptr, const float *input_to_input_weights_ptr,
  const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr,
  const float *input_to_output_weights_ptr, const float *aux_input_ptr,
  const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr,
  const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr,
  const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr,
  const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr,
  const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr,
  const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr,
  const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr,
  const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr,
  const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr,
  const float *output_gate_bias_ptr, const float *projection_weights_ptr,
  const float *projection_bias_ptr, const TfLiteLSTMParams *params, int n_batch, int n_cell,
  int n_input, int n_aux_input, int n_output, int output_batch_leading_dim,
  float *output_state_ptr, float *cell_state_ptr, float *scratch0, float *scratch1,
  float *scratch2, float *scratch3, float *output_ptr)
{
  // Since we have already checked that the weights are all there or none, checking the existence
  // of only one is enough to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float *input_gate_scratch = scratch0;
  float *forget_gate_scratch = scratch1;
  float *cell_gate_scratch = scratch2;
  float *output_gate_scratch = scratch3;
  // Check if the inputs are all zeros so some computations can be skipped.
  const bool is_input_all_zeros = tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
    (aux_input_ptr == nullptr ||
     tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg)
  {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
                           aux_input_to_input_weights_ptr, output_state_ptr,
                           recurrent_to_input_weights_ptr, cell_state_ptr,
                           cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
                           input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                           kTfLiteActSigmoid, input_gate_scratch, is_input_all_zeros,
                           is_aux_input_all_zeros);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
                         aux_input_to_forget_weights_ptr, output_state_ptr,
                         recurrent_to_forget_weights_ptr, cell_state_ptr,
                         cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
                         forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                         kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
                         is_aux_input_all_zeros);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
                         aux_input_to_cell_weights_ptr, output_state_ptr,
                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
                         /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr,
                         cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                         params->activation, cell_gate_scratch, is_input_all_zeros,
                         is_aux_input_all_zeros);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
                      cell_gate_scratch, use_cifg, params->cell_clip);
  // Calculate the output gate.
  CalculateLstmGateFloat(input_ptr, input_to_output_weights_ptr, aux_input_ptr,
                         aux_input_to_output_weights_ptr, output_state_ptr,
                         recurrent_to_output_weights_ptr, cell_state_ptr,
                         cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
                         output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                         kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
                         is_aux_input_all_zeros);
  // Update the output state.
  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
                           params->activation, projection_weights_ptr, projection_bias_ptr,
                           params->proj_clip, output_state_ptr, scratch2);
  // Copy the output state to the output. Note that the output's rows may not be contiguous
  // (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++)
  {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}
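
// A single LstmStepFloat call advances the LSTM by one time step for a whole batch:
//   gates : i = sigmoid(...), f = sigmoid(...), g = params->activation(...), o = sigmoid(...)
//   cell  : cell_state = f .* cell_state + i .* g        (i = 1 - f when CIFG is used)
//   output: output_state = o .* params->activation(cell_state), optionally projected and clipped
// The result rows are written to output_ptr with a stride of output_batch_leading_dim.
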
void EvalFloat(const Tensor *input,

               const Tensor *input_to_input_weights, const Tensor *input_to_forget_weights,
               const Tensor *input_to_cell_weights, const Tensor *input_to_output_weights,

               const Tensor *recurrent_to_input_weights, const Tensor *recurrent_to_forget_weights,
               const Tensor *recurrent_to_cell_weights, const Tensor *recurrent_to_output_weights,

               const Tensor *cell_to_input_weights, const Tensor *cell_to_forget_weights,
               const Tensor *cell_to_output_weights,

               const Tensor *input_layer_norm_coefficients,
               const Tensor *forget_layer_norm_coefficients,
               const Tensor *cell_layer_norm_coefficients,
               const Tensor *output_layer_norm_coefficients,

               const Tensor *aux_input, const Tensor *aux_input_to_input_weights,
               const Tensor *aux_input_to_forget_weights, const Tensor *aux_input_to_cell_weights,
               const Tensor *aux_input_to_output_weights,

               const Tensor *input_gate_bias, const Tensor *forget_gate_bias,
               const Tensor *cell_gate_bias, const Tensor *output_gate_bias,

               const Tensor *projection_weights, const Tensor *projection_bias,
               const TfLiteLSTMParams *params,

               bool forward_sequence, bool time_major, int output_offset,

               Tensor *scratch_buffer, Tensor *output_state, Tensor *cell_state, Tensor *output)
{
  const Shape &input_shape = input->shape();

  int max_time, n_batch;
  if (input_shape.num_dims() == 3)
  {
    max_time = (time_major) ? input_shape.dim(0) : input_shape.dim(1);
    n_batch = (time_major) ? input_shape.dim(1) : input_shape.dim(0);
  }
  else
  {
    max_time = 1;
    n_batch = input_shape.dim(0);
  }
  const int n_input = input_shape.dim(input_shape.num_dims() - 1);

  int aux_input_temp = 0;
  if (aux_input)
  {
    const Shape &aux_input_shape = aux_input->shape();
    aux_input_temp = aux_input_shape.dim(aux_input_shape.num_dims() - 1);
  }
  const int aux_input_size = aux_input_temp;

  // n_cell and n_output will be the same size when there is no projection.
  const Shape &input_to_output_weights_shape = input_to_output_weights->shape();
  const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights->shape();
  const int n_cell = input_to_output_weights_shape.dim(0);
  const int n_output = recurrent_to_output_weights_shape.dim(1);

  // Since we have already checked that the weights are all there or none, checking the existence
  // of only one is enough to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the per-gate scratch pointers into the shared scratch buffer.
  float *scratch_buffer_ptr = getTensorData<float>(scratch_buffer);
  float *input_gate_scratch = nullptr;
  float *cell_gate_scratch = nullptr;
  float *forget_gate_scratch = nullptr;
  float *output_gate_scratch = nullptr;
  if (use_cifg)
  {
    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  }
  else
  {
    input_gate_scratch = scratch_buffer_ptr;
    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }
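
  // At this point the shared scratch_buffer has been split into per-gate blocks of
  // n_cell * n_batch floats: three blocks under CIFG (the input gate is implicit) and four
  // blocks otherwise.
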
  const Shape &output_shape = output->shape();
  const int output_batch_leading_dim = output_shape.dim(output_shape.num_dims() - 1);
  if (time_major)
  {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++)
    {
      // If this is the forward sequence, step forward, otherwise step backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float *input_ptr = getTensorData<float>(input) + t_rel * input_step;
      const float *aux_input_ptr = nullptr;
      if (aux_input)
      {
        aux_input_ptr = getTensorData<float>(aux_input) + t_rel * input_step;
      }
      float *output_ptr = getTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepFloat(
        input_ptr, getTensorData<float>(input_to_input_weights),
        getTensorData<float>(input_to_forget_weights), getTensorData<float>(input_to_cell_weights),
        getTensorData<float>(input_to_output_weights), aux_input_ptr,
        getTensorData<float>(aux_input_to_input_weights),
        getTensorData<float>(aux_input_to_forget_weights),
        getTensorData<float>(aux_input_to_cell_weights),
        getTensorData<float>(aux_input_to_output_weights),
        getTensorData<float>(recurrent_to_input_weights),
        getTensorData<float>(recurrent_to_forget_weights),
        getTensorData<float>(recurrent_to_cell_weights),
        getTensorData<float>(recurrent_to_output_weights),
        getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
        getTensorData<float>(cell_to_output_weights),
        getTensorData<float>(input_layer_norm_coefficients),
        getTensorData<float>(forget_layer_norm_coefficients),
        getTensorData<float>(cell_layer_norm_coefficients),
        getTensorData<float>(output_layer_norm_coefficients), getTensorData<float>(input_gate_bias),
        getTensorData<float>(forget_gate_bias), getTensorData<float>(cell_gate_bias),
        getTensorData<float>(output_gate_bias), getTensorData<float>(projection_weights),
        getTensorData<float>(projection_bias), params, n_batch, n_cell, n_input, aux_input_size,
        n_output, output_batch_leading_dim, getTensorData<float>(output_state),
        getTensorData<float>(cell_state), input_gate_scratch, forget_gate_scratch,
        cell_gate_scratch, output_gate_scratch, output_ptr);
    }
  }
  else
  {
    for (int b = 0; b < n_batch; b++)
    {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++)
      {
        // If this is the forward sequence, step forward, otherwise step backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float *input_ptr = getTensorData<float>(input) + time_offset * input_step;
        const float *aux_input_ptr = nullptr;
        if (aux_input)
        {
          aux_input_ptr = getTensorData<float>(aux_input) + time_offset * input_step;
        }
        float *output_ptr =
          getTensorData<float>(output) + time_offset * output_step + output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float *output_state_ptr =
          getTensorData<float>(output_state) + b * output_batch_leading_dim;
        float *cell_state_ptr = getTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float *input_gate_scratch_ptr =
          input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float *forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float *cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float *output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
        LstmStepFloat(
          input_ptr, getTensorData<float>(input_to_input_weights),
          getTensorData<float>(input_to_forget_weights),
          getTensorData<float>(input_to_cell_weights),
          getTensorData<float>(input_to_output_weights), aux_input_ptr,
          getTensorData<float>(aux_input_to_input_weights),
          getTensorData<float>(aux_input_to_forget_weights),
          getTensorData<float>(aux_input_to_cell_weights),
          getTensorData<float>(aux_input_to_output_weights),
          getTensorData<float>(recurrent_to_input_weights),
          getTensorData<float>(recurrent_to_forget_weights),
          getTensorData<float>(recurrent_to_cell_weights),
          getTensorData<float>(recurrent_to_output_weights),
          getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
          getTensorData<float>(cell_to_output_weights),
          getTensorData<float>(input_layer_norm_coefficients),
          getTensorData<float>(forget_layer_norm_coefficients),
          getTensorData<float>(cell_layer_norm_coefficients),
          getTensorData<float>(output_layer_norm_coefficients),
          getTensorData<float>(input_gate_bias), getTensorData<float>(forget_gate_bias),
          getTensorData<float>(cell_gate_bias), getTensorData<float>(output_gate_bias),
          getTensorData<float>(projection_weights), getTensorData<float>(projection_bias), params,
          /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
          output_state_ptr, cell_state_ptr, input_gate_scratch_ptr, forget_gate_scratch_ptr,
          cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
      }
    }
  }
}