63 const float *aux_input,
const float *aux_input_to_gate_weights,
64 const float *output_state,
65 const float *recurrent_to_gate_weights,
const float *cell_state,
66 const float *cell_to_gate_weights,
67 const float *layer_norm_coefficients,
const float *gate_bias,
68 const int n_batch,
const int n_input,
const int n_aux_input,
69 const int n_output,
const int n_cell,
71 const bool is_input_all_zeros,
const bool is_aux_input_all_zeros)
73 const bool use_peephole = (cell_to_gate_weights !=
nullptr);
74 const bool use_layer_norm = (layer_norm_coefficients !=
nullptr);
80 std::fill_n(gate, n_cell * n_batch, 0.0f);
88 if (!is_input_all_zeros)
95 if (!is_aux_input_all_zeros)
286 const float *input_ptr,
const float *input_to_input_weights_ptr,
287 const float *input_to_forget_weights_ptr,
const float *input_to_cell_weights_ptr,
288 const float *input_to_output_weights_ptr,
const float *aux_input_ptr,
289 const float *aux_input_to_input_weights_ptr,
const float *aux_input_to_forget_weights_ptr,
290 const float *aux_input_to_cell_weights_ptr,
const float *aux_input_to_output_weights_ptr,
291 const float *recurrent_to_input_weights_ptr,
const float *recurrent_to_forget_weights_ptr,
292 const float *recurrent_to_cell_weights_ptr,
const float *recurrent_to_output_weights_ptr,
293 const float *cell_to_input_weights_ptr,
const float *cell_to_forget_weights_ptr,
294 const float *cell_to_output_weights_ptr,
const float *input_layer_norm_coefficients_ptr,
295 const float *forget_layer_norm_coefficients_ptr,
const float *cell_layer_norm_coefficients_ptr,
296 const float *output_layer_norm_coefficients_ptr,
const float *input_gate_bias_ptr,
297 const float *forget_gate_bias_ptr,
const float *cell_gate_bias_ptr,
298 const float *output_gate_bias_ptr,
const float *projection_weights_ptr,
299 const float *projection_bias_ptr,
const LSTMParams *params,
int n_batch,
int n_cell,
int n_input,
300 int n_aux_input,
int n_output,
int output_batch_leading_dim,
float *output_state_ptr,
301 float *cell_state_ptr,
float *scratch0,
float *scratch1,
float *scratch2,
float *scratch3,
306 const bool use_cifg = (input_to_input_weights_ptr ==
nullptr);
309 float *input_gate_scratch = scratch0;
310 float *forget_gate_scratch = scratch1;
311 float *cell_gate_scratch = scratch2;
312 float *output_gate_scratch = scratch3;
315 const bool is_input_all_zeros =
IsZeroVector(input_ptr, n_batch * n_input);
316 const bool is_aux_input_all_zeros =
317 (aux_input_ptr ==
nullptr ||
IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
322 aux_input_to_input_weights_ptr, output_state_ptr,
323 recurrent_to_input_weights_ptr, cell_state_ptr,
324 cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
325 input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
327 input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
331 aux_input_to_forget_weights_ptr, output_state_ptr,
332 recurrent_to_forget_weights_ptr, cell_state_ptr,
333 cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
334 forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
336 forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
339 input_ptr, input_to_cell_weights_ptr, aux_input_ptr, aux_input_to_cell_weights_ptr,
340 output_state_ptr, recurrent_to_cell_weights_ptr,
nullptr,
341 nullptr, cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, n_batch,
342 n_input, n_aux_input, n_output, n_cell, params->
activation, cell_gate_scratch,
343 is_input_all_zeros, is_aux_input_all_zeros);
345 UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
346 cell_gate_scratch, use_cifg, params->
cell_clip);
349 aux_input_to_output_weights_ptr, output_state_ptr,
350 recurrent_to_output_weights_ptr, cell_state_ptr,
351 cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
352 output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
354 output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
357 params->
activation, projection_weights_ptr, projection_bias_ptr,
358 params->
proj_clip, output_state_ptr, scratch2);
361 for (
int b = 0; b < n_batch; b++)
363 std::copy_n(output_state_ptr + b * n_output, n_output,
364 output_ptr + b * output_batch_leading_dim);
// Float computation of a single LSTM gate; the result is written to `gate`.
//
// Only the signature is present on this line. The notes below are taken from
// the partial body excerpts that appear earlier in this file and should be
// re-checked against the complete definition:
//  - `gate` is zero-filled over n_cell * n_batch floats before the guarded
//    input / aux-input steps, so it must hold at least that many elements.
//  - A step guarded by `!is_input_all_zeros` and another guarded by
//    `!is_aux_input_all_zeros` follow the fill; presumably these add the
//    input and auxiliary-input contributions (their bodies are not visible).
//  - Peephole use is derived from `cell_to_gate_weights != nullptr`.
//  - Layer-norm use is derived from `layer_norm_coefficients != nullptr`.
//  - One call site in this file passes nullptr for both `cell_state` and
//    `cell_to_gate_weights` (the cell-gate computation), so a null
//    `cell_state` is expected to be acceptable when peephole is off.
//
// NOTE(review): tensor layouts (row/column order of the weight matrices,
// batch-major vs. cell-major buffers) are not visible here -- confirm before
// relying on any particular shape.
void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights, const float *aux_input, const float *aux_input_to_gate_weights, const float *output_state, const float *recurrent_to_gate_weights, const float *cell_state, const float *cell_to_gate_weights, const float *layer_norm_coefficients, const float *gate_bias, const int n_batch, const int n_input, const int n_aux_input, const int n_output, const int n_cell, const FusedActivationFunctionType activation, float *gate, const bool is_input_all_zeros, const bool is_aux_input_all_zeros)
// Float LSTM step (by its name, a single time step -- confirm against callers).
//
// Only the signature is present on this line. The notes below are taken from
// the partial body excerpts that appear earlier in this file and should be
// re-checked against the complete definition:
//  - CIFG mode is inferred from `input_to_input_weights_ptr == nullptr`.
//  - scratch0, scratch1, scratch2 and scratch3 are used as the input, forget,
//    cell and output gate buffers respectively.
//  - The main input is tested with IsZeroVector over n_batch * n_input
//    elements. The auxiliary input is treated as all-zero when
//    `aux_input_ptr` is nullptr, otherwise it is tested over
//    n_batch * n_aux_input elements.
//  - The cell gate is computed with nullptr for both the cell state and the
//    cell-to-gate weights, and with `params->activation`.
//  - The cell state is updated through UpdateLstmCellFloat, which receives
//    the input, forget and cell gate buffers, the CIFG flag and
//    `params->cell_clip`.
//  - The output stage receives `params->activation`, the projection weights
//    and bias, `params->proj_clip`, `output_state_ptr` and scratch2 (the
//    cell-gate buffer is passed again at this point).
//  - Finally, for each batch b, n_output floats are copied from
//    `output_state_ptr + b * n_output` to
//    `output_ptr + b * output_batch_leading_dim`.
//
// NOTE(review): whether the input-gate computation is skipped in CIFG mode,
// and which activation the input/forget/output gates use, falls on lines that
// are not visible in the excerpts -- verify in the full body.
void LstmStepFloat(const float *input_ptr, const float *input_to_input_weights_ptr, const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr, const float *input_to_output_weights_ptr, const float *aux_input_ptr, const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr, const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr, const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr, const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr, const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr, const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr, const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr, const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr, const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr, const float *output_gate_bias_ptr, const float *projection_weights_ptr, const float *projection_bias_ptr, const LSTMParams *params, int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr, float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3, float *output_ptr)