34void UpdateLstmCellFloat(
int n_batch,
int n_cell,
float *cell_state,
const float *input_gate,
35 float *forget_gate,
const float *cell_gate,
bool use_cifg,
float clip)
39#pragma GCC diagnostic ignored "-Wrestrict"
40 tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state, n_batch * n_cell, cell_state);
47 float *scratch = forget_gate;
48 tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
49 tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, scratch, n_batch * n_cell,
54 tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, input_gate, n_batch * n_cell,
59 tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
63void CalculateLstmOutputFloat(
int n_batch,
int n_cell,
int n_output,
const float *cell_state,
64 const float *output_gate, TfLiteFusedActivation activation,
65 const float *projection_weights,
const float *projection_bias,
66 const float proj_clip,
float *output_state,
float *scratch)
68 tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch);
69 tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell, scratch);
71 const bool use_projection = (projection_weights !=
nullptr);
72 const bool use_projection_bias = (projection_bias !=
nullptr);
76 if (use_projection_bias)
78 tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch, output_state);
82 std::fill_n(output_state, n_batch * n_output, 0.0f);
84 tensor_utils::MatrixBatchVectorMultiplyAccumulate(projection_weights, n_output, n_cell, scratch,
85 n_batch, output_state);
88 tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
93 std::copy_n(scratch, n_batch * n_output, output_state);
97inline void CalculateLstmGateFloat(
const float *input,
const float *input_to_gate_weights,
98 const float *aux_input,
const float *aux_input_to_gate_weights,
99 const float *output_state,
100 const float *recurrent_to_gate_weights,
const float *cell_state,
101 const float *cell_to_gate_weights,
102 const float *layer_norm_coefficients,
const float *gate_bias,
103 const int n_batch,
const int n_input,
const int n_aux_input,
104 const int n_output,
const int n_cell,
105 const TfLiteFusedActivation activation,
float *gate,
106 const bool is_input_all_zeros,
const bool is_aux_input_all_zeros)
108 const bool use_peephole = (cell_to_gate_weights !=
nullptr);
109 const bool use_layer_norm = (layer_norm_coefficients !=
nullptr);
115 std::fill_n(gate, n_cell * n_batch, 0.0f);
119 tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
123 if (!is_input_all_zeros)
125 tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, n_cell, n_input, input,
130 if (!is_aux_input_all_zeros)
132 tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, n_cell,
133 n_aux_input, aux_input, n_batch, gate);
136 tensor_utils::MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, n_cell, n_output,
137 output_state, n_batch, gate);
141 tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_gate_weights, n_cell, cell_state,
147 tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
148 tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, gate, n_batch,
150 tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
153 tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation, gate);
156inline void LstmStepFloat(
157 const float *input_ptr,
const float *input_to_input_weights_ptr,
158 const float *input_to_forget_weights_ptr,
const float *input_to_cell_weights_ptr,
159 const float *input_to_output_weights_ptr,
const float *aux_input_ptr,
160 const float *aux_input_to_input_weights_ptr,
const float *aux_input_to_forget_weights_ptr,
161 const float *aux_input_to_cell_weights_ptr,
const float *aux_input_to_output_weights_ptr,
162 const float *recurrent_to_input_weights_ptr,
const float *recurrent_to_forget_weights_ptr,
163 const float *recurrent_to_cell_weights_ptr,
const float *recurrent_to_output_weights_ptr,
164 const float *cell_to_input_weights_ptr,
const float *cell_to_forget_weights_ptr,
165 const float *cell_to_output_weights_ptr,
const float *input_layer_norm_coefficients_ptr,
166 const float *forget_layer_norm_coefficients_ptr,
const float *cell_layer_norm_coefficients_ptr,
167 const float *output_layer_norm_coefficients_ptr,
const float *input_gate_bias_ptr,
168 const float *forget_gate_bias_ptr,
const float *cell_gate_bias_ptr,
169 const float *output_gate_bias_ptr,
const float *projection_weights_ptr,
170 const float *projection_bias_ptr,
const TfLiteLSTMParams *params,
int n_batch,
int n_cell,
171 int n_input,
int n_aux_input,
int n_output,
int output_batch_leading_dim,
float *output_state_ptr,
172 float *cell_state_ptr,
float *scratch0,
float *scratch1,
float *scratch2,
float *scratch3,
177 const bool use_cifg = (input_to_input_weights_ptr ==
nullptr);
180 float *input_gate_scratch = scratch0;
181 float *forget_gate_scratch = scratch1;
182 float *cell_gate_scratch = scratch2;
183 float *output_gate_scratch = scratch3;
186 const bool is_input_all_zeros = tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
187 const bool is_aux_input_all_zeros =
188 (aux_input_ptr ==
nullptr || tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
192 CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
193 aux_input_to_input_weights_ptr, output_state_ptr,
194 recurrent_to_input_weights_ptr, cell_state_ptr,
195 cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
196 input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
197 kTfLiteActSigmoid, input_gate_scratch, is_input_all_zeros,
198 is_aux_input_all_zeros);
201 CalculateLstmGateFloat(input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
202 aux_input_to_forget_weights_ptr, output_state_ptr,
203 recurrent_to_forget_weights_ptr, cell_state_ptr,
204 cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
205 forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
206 kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
207 is_aux_input_all_zeros);
209 CalculateLstmGateFloat(
210 input_ptr, input_to_cell_weights_ptr, aux_input_ptr, aux_input_to_cell_weights_ptr,
211 output_state_ptr, recurrent_to_cell_weights_ptr,
nullptr,
212 nullptr, cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, n_batch,
213 n_input, n_aux_input, n_output, n_cell, params->activation, cell_gate_scratch,
214 is_input_all_zeros, is_aux_input_all_zeros);
216 UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
217 cell_gate_scratch, use_cifg, params->cell_clip);
219 CalculateLstmGateFloat(input_ptr, input_to_output_weights_ptr, aux_input_ptr,
220 aux_input_to_output_weights_ptr, output_state_ptr,
221 recurrent_to_output_weights_ptr, cell_state_ptr,
222 cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
223 output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
224 kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
225 is_aux_input_all_zeros);
227 CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
228 params->activation, projection_weights_ptr, projection_bias_ptr,
229 params->proj_clip, output_state_ptr, scratch2);
232 for (
int b = 0; b < n_batch; b++)
234 std::copy_n(output_state_ptr + b * n_output, n_output,
235 output_ptr + b * output_batch_leading_dim);
243 const Tensor *input_to_input_weights,
const Tensor *input_to_forget_weights,
244 const Tensor *input_to_cell_weights,
const Tensor *input_to_output_weights,
246 const Tensor *recurrent_to_input_weights,
const Tensor *recurrent_to_forget_weights,
247 const Tensor *recurrent_to_cell_weights,
const Tensor *recurrent_to_output_weights,
249 const Tensor *cell_to_input_weights,
const Tensor *cell_to_forget_weights,
250 const Tensor *cell_to_output_weights,
252 const Tensor *input_layer_norm_coefficients,
253 const Tensor *forget_layer_norm_coefficients,
254 const Tensor *cell_layer_norm_coefficients,
255 const Tensor *output_layer_norm_coefficients,
257 const Tensor *aux_input,
const Tensor *aux_input_to_input_weights,
258 const Tensor *aux_input_to_forget_weights,
const Tensor *aux_input_to_cell_weights,
259 const Tensor *aux_input_to_output_weights,
261 const Tensor *input_gate_bias,
const Tensor *forget_gate_bias,
262 const Tensor *cell_gate_bias,
const Tensor *output_gate_bias,
264 const Tensor *projection_weights,
const Tensor *projection_bias,
265 const TfLiteLSTMParams *params,
267 bool forward_sequence,
bool time_major,
int output_offset,
271 const Shape &input_shape = input->shape();
273 int max_time, n_batch;
276 max_time = (time_major) ? input_shape.
dim(0) : input_shape.
dim(1);
277 n_batch = (time_major) ? input_shape.
dim(1) : input_shape.
dim(0);
282 n_batch = input_shape.
dim(0);
284 const int n_input = input_shape.
dim(input_shape.
num_dims() - 1);
286 int aux_input_temp = 0;
289 const Shape &aux_input_shape = aux_input->
shape();
290 aux_input_temp = aux_input_shape.
dim(aux_input_shape.
num_dims() - 1);
292 const int aux_input_size = aux_input_temp;
295 const Shape &input_to_output_weights_shape = input_to_output_weights->
shape();
296 const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights->
shape();
297 const int n_cell = input_to_output_weights_shape.
dim(0);
298 const int n_output = recurrent_to_output_weights_shape.
dim(1);
302 const bool use_cifg = (input_to_input_weights ==
nullptr);
305 float *scratch_buffer_ptr = getTensorData<float>(scratch_buffer);
306 float *input_gate_scratch =
nullptr;
307 float *cell_gate_scratch =
nullptr;
308 float *forget_gate_scratch =
nullptr;
309 float *output_gate_scratch =
nullptr;
312 cell_gate_scratch = scratch_buffer_ptr;
313 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
314 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
318 input_gate_scratch = scratch_buffer_ptr;
319 cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
320 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
321 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
329 const int input_step = n_batch * n_input;
330 const int output_step = n_batch * output_batch_leading_dim;
331 for (
int t = 0; t < max_time; t++)
335 const int t_rel = forward_sequence ? t : max_time - t - 1;
336 const float *input_ptr = getTensorData<float>(input) + t_rel * input_step;
337 const float *aux_input_ptr =
nullptr;
340 aux_input_ptr = getTensorData<float>(aux_input) + t_rel * input_step;
342 float *output_ptr = getTensorData<float>(output) + t_rel * output_step + output_offset;
345 input_ptr, getTensorData<float>(input_to_input_weights),
346 getTensorData<float>(input_to_forget_weights), getTensorData<float>(input_to_cell_weights),
347 getTensorData<float>(input_to_output_weights), aux_input_ptr,
348 getTensorData<float>(aux_input_to_input_weights),
349 getTensorData<float>(aux_input_to_forget_weights),
350 getTensorData<float>(aux_input_to_cell_weights),
351 getTensorData<float>(aux_input_to_output_weights),
352 getTensorData<float>(recurrent_to_input_weights),
353 getTensorData<float>(recurrent_to_forget_weights),
354 getTensorData<float>(recurrent_to_cell_weights),
355 getTensorData<float>(recurrent_to_output_weights),
356 getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
357 getTensorData<float>(cell_to_output_weights),
358 getTensorData<float>(input_layer_norm_coefficients),
359 getTensorData<float>(forget_layer_norm_coefficients),
360 getTensorData<float>(cell_layer_norm_coefficients),
361 getTensorData<float>(output_layer_norm_coefficients), getTensorData<float>(input_gate_bias),
362 getTensorData<float>(forget_gate_bias), getTensorData<float>(cell_gate_bias),
363 getTensorData<float>(output_gate_bias), getTensorData<float>(projection_weights),
364 getTensorData<float>(projection_bias), params, n_batch, n_cell, n_input, aux_input_size,
365 n_output, output_batch_leading_dim, getTensorData<float>(output_state),
366 getTensorData<float>(cell_state), input_gate_scratch, forget_gate_scratch,
367 cell_gate_scratch, output_gate_scratch, output_ptr);
372 for (
int b = 0; b < n_batch; b++)
374 const int input_step = n_input;
375 const int output_step = output_batch_leading_dim;
376 for (
int t = 0; t < max_time; t++)
380 const int t_rel = forward_sequence ? t : max_time - t - 1;
381 const int time_offset = b * max_time + t_rel;
382 const float *input_ptr = getTensorData<float>(input) + time_offset * input_step;
383 const float *aux_input_ptr =
nullptr;
386 aux_input_ptr = getTensorData<float>(aux_input) + time_offset * input_step;
389 getTensorData<float>(output) + time_offset * output_step + output_offset;
392 float *output_state_ptr = getTensorData<float>(output_state) + b * output_batch_leading_dim;
393 float *cell_state_ptr = getTensorData<float>(cell_state) + b * n_cell;
395 float *input_gate_scratch_ptr =
396 input_gate_scratch ? input_gate_scratch + b * n_cell :
nullptr;
397 float *forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
398 float *cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
399 float *output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
402 input_ptr, getTensorData<float>(input_to_input_weights),
403 getTensorData<float>(input_to_forget_weights),
404 getTensorData<float>(input_to_cell_weights),
405 getTensorData<float>(input_to_output_weights), aux_input_ptr,
406 getTensorData<float>(aux_input_to_input_weights),
407 getTensorData<float>(aux_input_to_forget_weights),
408 getTensorData<float>(aux_input_to_cell_weights),
409 getTensorData<float>(aux_input_to_output_weights),
410 getTensorData<float>(recurrent_to_input_weights),
411 getTensorData<float>(recurrent_to_forget_weights),
412 getTensorData<float>(recurrent_to_cell_weights),
413 getTensorData<float>(recurrent_to_output_weights),
414 getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
415 getTensorData<float>(cell_to_output_weights),
416 getTensorData<float>(input_layer_norm_coefficients),
417 getTensorData<float>(forget_layer_norm_coefficients),
418 getTensorData<float>(cell_layer_norm_coefficients),
419 getTensorData<float>(output_layer_norm_coefficients),
420 getTensorData<float>(input_gate_bias), getTensorData<float>(forget_gate_bias),
421 getTensorData<float>(cell_gate_bias), getTensorData<float>(output_gate_bias),
422 getTensorData<float>(projection_weights), getTensorData<float>(projection_bias), params,
423 1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
424 output_state_ptr, cell_state_ptr, input_gate_scratch_ptr, forget_gate_scratch_ptr,
425 cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);