{
  const Shape &input_shape = input->shape();
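
  // A 3-D input is laid out [max_time, n_batch, n_input] when time_major and
  // [n_batch, max_time, n_input] otherwise; a 2-D input has no time dimension.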
  int max_time, n_batch;
  if (input_shape.num_dims() == 3)
  {
    max_time = (time_major) ? input_shape.dim(0) : input_shape.dim(1);
    n_batch = (time_major) ? input_shape.dim(1) : input_shape.dim(0);
  }
  else
  {
    max_time = 1;
    n_batch = input_shape.dim(0);
  }
  const int n_input = input_shape.dim(input_shape.num_dims() - 1);
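
  // The optional auxiliary input only contributes its depth (last dimension) here;
  // it is indexed with the same batch/time strides as the main input below.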
  int aux_input_temp = 0;
  if (aux_input)
  {
    const Shape &aux_input_shape = aux_input->shape();
    aux_input_temp = aux_input_shape.dim(aux_input_shape.num_dims() - 1);
  }
  const int aux_input_size = aux_input_temp;

  const Shape &input_to_output_weights_shape = input_to_output_weights->shape();
  const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights->shape();
  const int n_cell = input_to_output_weights_shape.dim(0);
  const int n_output = recurrent_to_output_weights_shape.dim(1);
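
  // CIFG (coupled input and forget gate) LSTMs have no separate input gate;
  // this is detected by the absence of the input-to-input weights.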
  const bool use_cifg = (input_to_input_weights == nullptr);
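
  // Partition the scratch buffer into one [n_batch, n_cell] area per gate;
  // the CIFG variant needs only three of them.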
  float *scratch_buffer_ptr = getTensorData<float>(scratch_buffer);
  float *input_gate_scratch = nullptr;
  float *cell_gate_scratch = nullptr;
  float *forget_gate_scratch = nullptr;
  float *output_gate_scratch = nullptr;
  if (use_cifg)
  {
    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  }
  else
  {
    input_gate_scratch = scratch_buffer_ptr;
    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }
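
  // Time-major path: all batch entries of a time step are contiguous, so a
  // single LstmStepFloat call per step covers the whole batch at once.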
  if (time_major)
  {
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++)
    {
      // Step forward through time, or backward when forward_sequence is false.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float *input_ptr = getTensorData<float>(input) + t_rel * input_step;
      const float *aux_input_ptr = nullptr;
      if (aux_input)
      {
        aux_input_ptr = getTensorData<float>(aux_input) + t_rel * input_step;
      }
      float *output_ptr = getTensorData<float>(output) + t_rel * output_step + output_offset;
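
      // Run one LSTM step over the full batch at time t_rel.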
      LstmStepFloat(
        input_ptr, getTensorData<float>(input_to_input_weights),
        getTensorData<float>(input_to_forget_weights), getTensorData<float>(input_to_cell_weights),
        getTensorData<float>(input_to_output_weights), aux_input_ptr,
        getTensorData<float>(aux_input_to_input_weights),
        getTensorData<float>(aux_input_to_forget_weights),
        getTensorData<float>(aux_input_to_cell_weights),
        getTensorData<float>(aux_input_to_output_weights),
        getTensorData<float>(recurrent_to_input_weights),
        getTensorData<float>(recurrent_to_forget_weights),
        getTensorData<float>(recurrent_to_cell_weights),
        getTensorData<float>(recurrent_to_output_weights),
        getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
        getTensorData<float>(cell_to_output_weights),
        getTensorData<float>(input_layer_norm_coefficients),
        getTensorData<float>(forget_layer_norm_coefficients),
        getTensorData<float>(cell_layer_norm_coefficients),
        getTensorData<float>(output_layer_norm_coefficients), getTensorData<float>(input_gate_bias),
        getTensorData<float>(forget_gate_bias), getTensorData<float>(cell_gate_bias),
        getTensorData<float>(output_gate_bias), getTensorData<float>(projection_weights),
        getTensorData<float>(projection_bias), params, n_batch, n_cell, n_input, aux_input_size,
        n_output, output_batch_leading_dim, getTensorData<float>(output_state),
        getTensorData<float>(cell_state), input_gate_scratch, forget_gate_scratch,
        cell_gate_scratch, output_gate_scratch, output_ptr);
    }
  }
  else
  {
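    // Batch-major path: process each batch entry independently, one time step
    // at a time, so each LstmStepFloat call below runs with a batch size of 1.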
    for (int b = 0; b < n_batch; b++)
    {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++)
      {
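        // In batch-major layout the flat (batch, time) index is b * max_time + t_rel,
        // stepping backward through time when forward_sequence is false.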
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float *input_ptr = getTensorData<float>(input) + time_offset * input_step;
        const float *aux_input_ptr = nullptr;
        if (aux_input)
        {
          aux_input_ptr = getTensorData<float>(aux_input) + time_offset * input_step;
        }
        float *output_ptr =
          getTensorData<float>(output) + time_offset * output_step + output_offset;
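
        // Point the recurrent state and per-gate scratch buffers at batch entry b.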
        float *output_state_ptr =
          getTensorData<float>(output_state) + b * output_batch_leading_dim;
        float *cell_state_ptr = getTensorData<float>(cell_state) + b * n_cell;
        float *input_gate_scratch_ptr =
          input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float *forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float *cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float *output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
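
        // Run a single-batch LSTM step for batch entry b at time t_rel.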
        LstmStepFloat(
          input_ptr, getTensorData<float>(input_to_input_weights),
          getTensorData<float>(input_to_forget_weights),
          getTensorData<float>(input_to_cell_weights),
          getTensorData<float>(input_to_output_weights), aux_input_ptr,
          getTensorData<float>(aux_input_to_input_weights),
          getTensorData<float>(aux_input_to_forget_weights),
          getTensorData<float>(aux_input_to_cell_weights),
          getTensorData<float>(aux_input_to_output_weights),
          getTensorData<float>(recurrent_to_input_weights),
          getTensorData<float>(recurrent_to_forget_weights),
          getTensorData<float>(recurrent_to_cell_weights),
          getTensorData<float>(recurrent_to_output_weights),
          getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
          getTensorData<float>(cell_to_output_weights),
          getTensorData<float>(input_layer_norm_coefficients),
          getTensorData<float>(forget_layer_norm_coefficients),
          getTensorData<float>(cell_layer_norm_coefficients),
          getTensorData<float>(output_layer_norm_coefficients),
          getTensorData<float>(input_gate_bias), getTensorData<float>(forget_gate_bias),
          getTensorData<float>(cell_gate_bias), getTensorData<float>(output_gate_bias),
          getTensorData<float>(projection_weights), getTensorData<float>(projection_bias), params,
          /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
          output_state_ptr, cell_state_ptr, input_gate_scratch_ptr, forget_gate_scratch_ptr,
          cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
      }
    }
  }
}

// Signature of the per-step LSTM kernel invoked above.
void LstmStepFloat(
  const float *input_ptr, const float *input_to_input_weights_ptr,
  const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr,
  const float *input_to_output_weights_ptr, const float *aux_input_ptr,
  const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr,
  const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr,
  const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr,
  const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr,
  const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr,
  const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr,
  const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr,
  const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr,
  const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr,
  const float *output_gate_bias_ptr, const float *projection_weights_ptr,
  const float *projection_bias_ptr, const LSTMParams *params, int n_batch, int n_cell, int n_input,
  int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr,
  float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3,
  float *output_ptr);