267{
268 const Shape &input_shape = input->shape();
270 int max_time, n_batch;
272 {
273 max_time = (time_major) ? input_shape.
dim(0) : input_shape.
dim(1);
274 n_batch = (time_major) ? input_shape.
dim(1) : input_shape.
dim(0);
275 }
276 else
277 {
278 max_time = 1;
279 n_batch = input_shape.
dim(0);
280 }
281 const int n_input = input_shape.
dim(input_shape.
num_dims() - 1);
282
283 int aux_input_temp = 0;
284 if (aux_input)
285 {
286 const Shape &aux_input_shape = aux_input->
shape();
287 aux_input_temp = aux_input_shape.
dim(aux_input_shape.
num_dims() - 1);
288 }
289 const int aux_input_size = aux_input_temp;
290
291
292 const Shape &input_to_output_weights_shape = input_to_output_weights->
shape();
293 const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights->
shape();
294 const int n_cell = input_to_output_weights_shape.dim(0);
295 const int n_output = recurrent_to_output_weights_shape.dim(1);
296
297
298
299 const bool use_cifg = (input_to_input_weights == nullptr);
300
301
302 float *scratch_buffer_ptr = getTensorData<float>(scratch_buffer);
303 float *input_gate_scratch = nullptr;
304 float *cell_gate_scratch = nullptr;
305 float *forget_gate_scratch = nullptr;
306 float *output_gate_scratch = nullptr;
307 if (use_cifg)
308 {
309 cell_gate_scratch = scratch_buffer_ptr;
310 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
311 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
312 }
313 else
314 {
315 input_gate_scratch = scratch_buffer_ptr;
316 cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
317 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
318 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
319 }
320
323 if (time_major)
324 {
325
326 const int input_step = n_batch * n_input;
327 const int output_step = n_batch * output_batch_leading_dim;
328 for (int t = 0; t < max_time; t++)
329 {
330
331
332 const int t_rel = forward_sequence ? t : max_time - t - 1;
333 const float *input_ptr = getTensorData<float>(input) + t_rel * input_step;
334 const float *aux_input_ptr = nullptr;
335 if (aux_input)
336 {
337 aux_input_ptr = getTensorData<float>(aux_input) + t_rel * input_step;
338 }
339 float *output_ptr = getTensorData<float>(output) + t_rel * output_step + output_offset;
340
342 input_ptr, getTensorData<float>(input_to_input_weights),
343 getTensorData<float>(input_to_forget_weights), getTensorData<float>(input_to_cell_weights),
344 getTensorData<float>(input_to_output_weights), aux_input_ptr,
345 getTensorData<float>(aux_input_to_input_weights),
346 getTensorData<float>(aux_input_to_forget_weights),
347 getTensorData<float>(aux_input_to_cell_weights),
348 getTensorData<float>(aux_input_to_output_weights),
349 getTensorData<float>(recurrent_to_input_weights),
350 getTensorData<float>(recurrent_to_forget_weights),
351 getTensorData<float>(recurrent_to_cell_weights),
352 getTensorData<float>(recurrent_to_output_weights),
353 getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
354 getTensorData<float>(cell_to_output_weights),
355 getTensorData<float>(input_layer_norm_coefficients),
356 getTensorData<float>(forget_layer_norm_coefficients),
357 getTensorData<float>(cell_layer_norm_coefficients),
358 getTensorData<float>(output_layer_norm_coefficients), getTensorData<float>(input_gate_bias),
359 getTensorData<float>(forget_gate_bias), getTensorData<float>(cell_gate_bias),
360 getTensorData<float>(output_gate_bias), getTensorData<float>(projection_weights),
361 getTensorData<float>(projection_bias), params, n_batch, n_cell, n_input, aux_input_size,
362 n_output, output_batch_leading_dim, getTensorData<float>(output_state),
363 getTensorData<float>(cell_state), input_gate_scratch, forget_gate_scratch,
364 cell_gate_scratch, output_gate_scratch, output_ptr);
365 }
366 }
367 else
368 {
369 for (
int b = 0;
b < n_batch;
b++)
370 {
371 const int input_step = n_input;
372 const int output_step = output_batch_leading_dim;
373 for (int t = 0; t < max_time; t++)
374 {
375
376
377 const int t_rel = forward_sequence ? t : max_time - t - 1;
378 const int time_offset =
b * max_time + t_rel;
379 const float *input_ptr = getTensorData<float>(input) + time_offset * input_step;
380 const float *aux_input_ptr = nullptr;
381 if (aux_input)
382 {
383 aux_input_ptr = getTensorData<float>(aux_input) + time_offset * input_step;
384 }
385 float *output_ptr =
386 getTensorData<float>(output) + time_offset * output_step + output_offset;
387
388
389 float *output_state_ptr = getTensorData<float>(output_state) +
b * output_batch_leading_dim;
390 float *cell_state_ptr = getTensorData<float>(cell_state) +
b * n_cell;
391
392 float *input_gate_scratch_ptr =
393 input_gate_scratch ? input_gate_scratch +
b * n_cell :
nullptr;
394 float *forget_gate_scratch_ptr = forget_gate_scratch +
b * n_cell;
395 float *cell_gate_scratch_ptr = cell_gate_scratch +
b * n_cell;
396 float *output_gate_scratch_ptr = output_gate_scratch +
b * n_cell;
397
399 input_ptr, getTensorData<float>(input_to_input_weights),
400 getTensorData<float>(input_to_forget_weights),
401 getTensorData<float>(input_to_cell_weights),
402 getTensorData<float>(input_to_output_weights), aux_input_ptr,
403 getTensorData<float>(aux_input_to_input_weights),
404 getTensorData<float>(aux_input_to_forget_weights),
405 getTensorData<float>(aux_input_to_cell_weights),
406 getTensorData<float>(aux_input_to_output_weights),
407 getTensorData<float>(recurrent_to_input_weights),
408 getTensorData<float>(recurrent_to_forget_weights),
409 getTensorData<float>(recurrent_to_cell_weights),
410 getTensorData<float>(recurrent_to_output_weights),
411 getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
412 getTensorData<float>(cell_to_output_weights),
413 getTensorData<float>(input_layer_norm_coefficients),
414 getTensorData<float>(forget_layer_norm_coefficients),
415 getTensorData<float>(cell_layer_norm_coefficients),
416 getTensorData<float>(output_layer_norm_coefficients),
417 getTensorData<float>(input_gate_bias), getTensorData<float>(forget_gate_bias),
418 getTensorData<float>(cell_gate_bias), getTensorData<float>(output_gate_bias),
419 getTensorData<float>(projection_weights), getTensorData<float>(projection_bias), params,
420 1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
421 output_state_ptr, cell_state_ptr, input_gate_scratch_ptr, forget_gate_scratch_ptr,
422 cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
423 }
424 }
425 }
426}
const Shape & shape() const
const luci_interpreter::RuntimeShape output_shape
void LstmStepFloat(const float *input_ptr, const float *input_to_input_weights_ptr, const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr, const float *input_to_output_weights_ptr, const float *aux_input_ptr, const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr, const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr, const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr, const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr, const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr, const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr, const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr, const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr, const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr, const float *output_gate_bias_ptr, const float *projection_weights_ptr, const float *projection_bias_ptr, const LSTMParams *params, int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr, float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3, float *output_ptr)