ONE - On-device Neural Engine
UnidirectionalSequenceLSTM.cpp
/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/UnidirectionalSequenceLSTM.h"
#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/tensor_utils.h>

namespace luci_interpreter
{
namespace kernels
{
namespace lstm
{
namespace
{

using namespace tflite;

void UpdateLstmCellFloat(int n_batch, int n_cell, float *cell_state, const float *input_gate,
                         float *forget_gate, const float *cell_gate, bool use_cifg, float clip)
{
// NOTE The tflite source is kept as-is, but it fails to build with gcc-8 and above.
// TODO Remove this #pragma once the warning is resolved upstream.
#pragma GCC diagnostic ignored "-Wrestrict"
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state, n_batch * n_cell, cell_state);

  if (use_cifg)
  {
    // With CIFG, input_gate = 1 - forget_gate. Use the forget_gate array as
    // scratch, as the input_gate array is not allocated in this case. (Be careful
    // not to write to the scratch before reading the forget gate data.)
    float *scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, scratch, n_batch * n_cell,
                                                     cell_state);
  }
  else
  {
    tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, input_gate, n_batch * n_cell,
                                                     cell_state);
  }
  if (clip > 0.0f)
  {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
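
// The tensor_utils calls above amount to the classic cell-state update
// c_t = f_t * c_{t-1} + i_t * g_t (with i_t = 1 - f_t under CIFG, and an
// optional symmetric clip). A minimal scalar sketch of the non-CIFG path,
// kept here purely for illustration and not part of the upstream kernel:
inline void UpdateLstmCellFloatSketch(int n_batch, int n_cell, float *cell_state,
                                      const float *input_gate, const float *forget_gate,
                                      const float *cell_gate, float clip)
{
  for (int i = 0; i < n_batch * n_cell; ++i)
  {
    float c = forget_gate[i] * cell_state[i] + input_gate[i] * cell_gate[i];
    // clip > 0 enables symmetric clipping, matching CwiseClipping above.
    if (clip > 0.0f)
      c = (c > clip) ? clip : ((c < -clip) ? -clip : c);
    cell_state[i] = c;
  }
}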

void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, const float *cell_state,
                              const float *output_gate, TfLiteFusedActivation activation,
                              const float *projection_weights, const float *projection_bias,
                              const float proj_clip, float *output_state, float *scratch)
{
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell, scratch);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection)
  {
    if (use_projection_bias)
    {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch, output_state);
    }
    else
    {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(projection_weights, n_output, n_cell, scratch,
                                                      n_batch, output_state);
    if (proj_clip > 0.0f)
    {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  }
  else
  {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}
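
// A scalar sketch of the no-projection path above (illustrative only, not part
// of the upstream kernel): h_t = o_t * act(c_t), shown here with tanh as the
// fused activation, in which case n_output == n_cell. Assumes <cmath> is
// available through the existing includes.
inline void CalculateLstmOutputFloatSketch(int n, const float *cell_state,
                                           const float *output_gate, float *output_state)
{
  for (int i = 0; i < n; ++i)
    output_state[i] = output_gate[i] * std::tanh(cell_state[i]);
}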

inline void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights,
                                   const float *aux_input, const float *aux_input_to_gate_weights,
                                   const float *output_state,
                                   const float *recurrent_to_gate_weights, const float *cell_state,
                                   const float *cell_to_gate_weights,
                                   const float *layer_norm_coefficients, const float *gate_bias,
                                   const int n_batch, const int n_input, const int n_aux_input,
                                   const int n_output, const int n_cell,
                                   const TfLiteFusedActivation activation, float *gate,
                                   const bool is_input_all_zeros, const bool is_aux_input_all_zeros)
{
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm)
  {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  }
  else
  {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros)
  {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, n_cell, n_input, input,
                                                      n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros)
  {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, n_cell,
                                                      n_aux_input, aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, n_cell, n_output,
                                                    output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole)
  {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_gate_weights, n_cell, cell_state,
                                                          n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm)
  {
    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, gate, n_batch,
                                                gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation, gate);
}
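
// In gate notation, the function above computes, per batch row:
//   gate = act(W_x * x + W_aux * aux + W_h * h_prev + w_c .* c_prev + b)
// with optional layer normalization applied before the bias is re-added.
// A minimal single-batch sketch without peephole, aux input, or layer norm
// (illustrative only, not part of the upstream kernel; row-major weights and
// sigmoid activation assumed, <cmath> assumed available):
inline void CalculateLstmGateFloatSketch(const float *x, const float *h_prev, const float *w_x,
                                         const float *w_h, const float *bias, int n_cell,
                                         int n_input, int n_output, float *gate)
{
  for (int c = 0; c < n_cell; ++c)
  {
    float acc = bias[c];
    for (int i = 0; i < n_input; ++i)
      acc += w_x[c * n_input + i] * x[i]; // input contribution
    for (int o = 0; o < n_output; ++o)
      acc += w_h[c * n_output + o] * h_prev[o]; // recurrent contribution
    gate[c] = 1.0f / (1.0f + std::exp(-acc));   // sigmoid
  }
}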

inline void LstmStepFloat(
  const float *input_ptr, const float *input_to_input_weights_ptr,
  const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr,
  const float *input_to_output_weights_ptr, const float *aux_input_ptr,
  const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr,
  const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr,
  const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr,
  const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr,
  const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr,
  const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr,
  const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr,
  const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr,
  const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr,
  const float *output_gate_bias_ptr, const float *projection_weights_ptr,
  const float *projection_bias_ptr, const TfLiteLSTMParams *params, int n_batch, int n_cell,
  int n_input, int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr,
  float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3,
  float *output_ptr)
{
  // Since we have already checked that the weights are either all present or
  // all absent, checking the existence of one is enough to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float *input_gate_scratch = scratch0;
  float *forget_gate_scratch = scratch1;
  float *cell_gate_scratch = scratch2;
  float *output_gate_scratch = scratch3;

  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros = tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
    (aux_input_ptr == nullptr || tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg)
  {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
                           aux_input_to_input_weights_ptr, output_state_ptr,
                           recurrent_to_input_weights_ptr, cell_state_ptr,
                           cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
                           input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                           /*activation=*/kTfLiteActSigmoid, input_gate_scratch, is_input_all_zeros,
                           is_aux_input_all_zeros);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
                         aux_input_to_forget_weights_ptr, output_state_ptr,
                         recurrent_to_forget_weights_ptr, cell_state_ptr,
                         cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
                         forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                         /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
                         is_aux_input_all_zeros);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(
    input_ptr, input_to_cell_weights_ptr, aux_input_ptr, aux_input_to_cell_weights_ptr,
    output_state_ptr, recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
    /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, n_batch,
    n_input, n_aux_input, n_output, n_cell, params->activation, cell_gate_scratch,
    is_input_all_zeros, is_aux_input_all_zeros);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
                      cell_gate_scratch, use_cifg, params->cell_clip);
  // Calculate the output gate.
  CalculateLstmGateFloat(input_ptr, input_to_output_weights_ptr, aux_input_ptr,
                         aux_input_to_output_weights_ptr, output_state_ptr,
                         recurrent_to_output_weights_ptr, cell_state_ptr,
                         cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
                         output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                         /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
                         is_aux_input_all_zeros);
  // Update the output state.
  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
                           params->activation, projection_weights_ptr, projection_bias_ptr,
                           params->proj_clip, output_state_ptr, scratch2);
  // Copy the output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++)
  {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}
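
// For reference, one step above implements the standard LSTM recurrence
// (peephole and layer-norm terms elided):
//
//   i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)   (i_t = 1 - f_t under CIFG)
//   f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)
//   g_t =    act(W_g x_t + U_g h_{t-1} + b_g)
//   o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)
//   c_t = f_t .* c_{t-1} + i_t .* g_t            (optionally clipped)
//   h_t = o_t .* act(c_t)                        (optionally projected)
//
// where act is params->activation. Note that cell_gate_scratch (scratch2) is
// reused as the activation scratch inside CalculateLstmOutputFloat; this is
// safe because the cell gate is no longer needed at that point.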

} // namespace

void EvalFloat(const Tensor *input,

               const Tensor *input_to_input_weights, const Tensor *input_to_forget_weights,
               const Tensor *input_to_cell_weights, const Tensor *input_to_output_weights,

               const Tensor *recurrent_to_input_weights, const Tensor *recurrent_to_forget_weights,
               const Tensor *recurrent_to_cell_weights, const Tensor *recurrent_to_output_weights,

               const Tensor *cell_to_input_weights, const Tensor *cell_to_forget_weights,
               const Tensor *cell_to_output_weights,

               const Tensor *input_layer_norm_coefficients,
               const Tensor *forget_layer_norm_coefficients,
               const Tensor *cell_layer_norm_coefficients,
               const Tensor *output_layer_norm_coefficients,

               const Tensor *aux_input, const Tensor *aux_input_to_input_weights,
               const Tensor *aux_input_to_forget_weights, const Tensor *aux_input_to_cell_weights,
               const Tensor *aux_input_to_output_weights,

               const Tensor *input_gate_bias, const Tensor *forget_gate_bias,
               const Tensor *cell_gate_bias, const Tensor *output_gate_bias,

               const Tensor *projection_weights, const Tensor *projection_bias,
               const TfLiteLSTMParams *params,

               bool forward_sequence, bool time_major, int output_offset,

               Tensor *scratch_buffer, Tensor *output_state, Tensor *cell_state, Tensor *output)
{
  const Shape &input_shape = input->shape();
  assert(input_shape.num_dims() >= 2 && input_shape.num_dims() <= 3);
  int max_time, n_batch;
  if (input_shape.num_dims() == 3)
  {
    max_time = (time_major) ? input_shape.dim(0) : input_shape.dim(1);
    n_batch = (time_major) ? input_shape.dim(1) : input_shape.dim(0);
  }
  else
  {
    max_time = 1;
    n_batch = input_shape.dim(0);
  }
  const int n_input = input_shape.dim(input_shape.num_dims() - 1);

  int aux_input_temp = 0;
  if (aux_input)
  {
    const Shape &aux_input_shape = aux_input->shape();
    aux_input_temp = aux_input_shape.dim(aux_input_shape.num_dims() - 1);
  }
  const int aux_input_size = aux_input_temp;

  // n_cell and n_output will be the same size when there is no projection.
  const Shape &input_to_output_weights_shape = input_to_output_weights->shape();
  const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights->shape();
  const int n_cell = input_to_output_weights_shape.dim(0);
  const int n_output = recurrent_to_output_weights_shape.dim(1);

  // Since we have already checked that the weights are either all present or
  // all absent, checking the existence of one is enough to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the scratch buffer pointers into the global scratch buffer.
  float *scratch_buffer_ptr = getTensorData<float>(scratch_buffer);
  float *input_gate_scratch = nullptr;
  float *cell_gate_scratch = nullptr;
  float *forget_gate_scratch = nullptr;
  float *output_gate_scratch = nullptr;
  if (use_cifg)
  {
    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  }
  else
  {
    input_gate_scratch = scratch_buffer_ptr;
    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }

  const Shape &output_shape = output->shape();
  const int output_batch_leading_dim = output_shape.dim(output_shape.num_dims() - 1);
  if (time_major)
  {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++)
    {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float *input_ptr = getTensorData<float>(input) + t_rel * input_step;
      const float *aux_input_ptr = nullptr;
      if (aux_input)
      {
        aux_input_ptr = getTensorData<float>(aux_input) + t_rel * input_step;
      }
      float *output_ptr = getTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepFloat(
        input_ptr, getTensorData<float>(input_to_input_weights),
        getTensorData<float>(input_to_forget_weights), getTensorData<float>(input_to_cell_weights),
        getTensorData<float>(input_to_output_weights), aux_input_ptr,
        getTensorData<float>(aux_input_to_input_weights),
        getTensorData<float>(aux_input_to_forget_weights),
        getTensorData<float>(aux_input_to_cell_weights),
        getTensorData<float>(aux_input_to_output_weights),
        getTensorData<float>(recurrent_to_input_weights),
        getTensorData<float>(recurrent_to_forget_weights),
        getTensorData<float>(recurrent_to_cell_weights),
        getTensorData<float>(recurrent_to_output_weights),
        getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
        getTensorData<float>(cell_to_output_weights),
        getTensorData<float>(input_layer_norm_coefficients),
        getTensorData<float>(forget_layer_norm_coefficients),
        getTensorData<float>(cell_layer_norm_coefficients),
        getTensorData<float>(output_layer_norm_coefficients), getTensorData<float>(input_gate_bias),
        getTensorData<float>(forget_gate_bias), getTensorData<float>(cell_gate_bias),
        getTensorData<float>(output_gate_bias), getTensorData<float>(projection_weights),
        getTensorData<float>(projection_bias), params, n_batch, n_cell, n_input, aux_input_size,
        n_output, output_batch_leading_dim, getTensorData<float>(output_state),
        getTensorData<float>(cell_state), input_gate_scratch, forget_gate_scratch,
        cell_gate_scratch, output_gate_scratch, output_ptr);
    }
  }
  else
  {
    for (int b = 0; b < n_batch; b++)
    {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++)
      {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float *input_ptr = getTensorData<float>(input) + time_offset * input_step;
        const float *aux_input_ptr = nullptr;
        if (aux_input)
        {
          aux_input_ptr = getTensorData<float>(aux_input) + time_offset * input_step;
        }
        float *output_ptr =
          getTensorData<float>(output) + time_offset * output_step + output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float *output_state_ptr = getTensorData<float>(output_state) + b * output_batch_leading_dim;
        float *cell_state_ptr = getTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float *input_gate_scratch_ptr =
          input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float *forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float *cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float *output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepFloat(
          input_ptr, getTensorData<float>(input_to_input_weights),
          getTensorData<float>(input_to_forget_weights),
          getTensorData<float>(input_to_cell_weights),
          getTensorData<float>(input_to_output_weights), aux_input_ptr,
          getTensorData<float>(aux_input_to_input_weights),
          getTensorData<float>(aux_input_to_forget_weights),
          getTensorData<float>(aux_input_to_cell_weights),
          getTensorData<float>(aux_input_to_output_weights),
          getTensorData<float>(recurrent_to_input_weights),
          getTensorData<float>(recurrent_to_forget_weights),
          getTensorData<float>(recurrent_to_cell_weights),
          getTensorData<float>(recurrent_to_output_weights),
          getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
          getTensorData<float>(cell_to_output_weights),
          getTensorData<float>(input_layer_norm_coefficients),
          getTensorData<float>(forget_layer_norm_coefficients),
          getTensorData<float>(cell_layer_norm_coefficients),
          getTensorData<float>(output_layer_norm_coefficients),
          getTensorData<float>(input_gate_bias), getTensorData<float>(forget_gate_bias),
          getTensorData<float>(cell_gate_bias), getTensorData<float>(output_gate_bias),
          getTensorData<float>(projection_weights), getTensorData<float>(projection_bias), params,
          /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
          output_state_ptr, cell_state_ptr, input_gate_scratch_ptr, forget_gate_scratch_ptr,
          cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
      }
    }
  }
}
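
// Layout note (illustrative helpers, not part of the upstream kernel): the two
// branches above differ only in how a (t, b) pair maps into the flat input
// buffer. Time-major input is [max_time, n_batch, n_input] and one step covers
// all batches at once; batch-major input is [n_batch, max_time, n_input] and
// each step covers a single batch row:
inline int TimeMajorInputOffset(int t, int b, int n_batch, int n_input)
{
  return (t * n_batch + b) * n_input;
}

inline int BatchMajorInputOffset(int t, int b, int max_time, int n_input)
{
  return (b * max_time + t) * n_input;
}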

} // namespace lstm
} // namespace kernels
} // namespace luci_interpreter

namespace luci_interpreter
{
namespace kernels
{

UnidirectionalSequenceLSTM::UnidirectionalSequenceLSTM(
  const Tensor *input,

  const Tensor *input_to_input_weights, const Tensor *input_to_forget_weights,
  const Tensor *input_to_cell_weights, const Tensor *input_to_output_weights,

  const Tensor *recurrent_to_input_weights, const Tensor *recurrent_to_forget_weights,
  const Tensor *recurrent_to_cell_weights, const Tensor *recurrent_to_output_weights,

  const Tensor *cell_to_input_weights, const Tensor *cell_to_forget_weights,
  const Tensor *cell_to_output_weights,

  const Tensor *input_gate_bias, const Tensor *forget_gate_bias, const Tensor *cell_gate_bias,
  const Tensor *output_gate_bias,

  const Tensor *projection_weights, const Tensor *projection_bias,

  const Tensor *output_state, const Tensor *cell_state, const Tensor *input_layer_norm_coefficients,
  const Tensor *forget_layer_norm_coefficients, const Tensor *cell_layer_norm_coefficients,
  const Tensor *output_layer_norm_coefficients,

  Tensor *output, Tensor *scratchpad_1, Tensor *scratchpad_2, Tensor *scratchpad_3,
  const UnidirectionalSequenceLSTMParams &params)
  : KernelWithParams<UnidirectionalSequenceLSTMParams>(
      {input,

       input_to_input_weights, input_to_forget_weights, input_to_cell_weights,
       input_to_output_weights,

       recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
       recurrent_to_output_weights,

       cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,

       input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,

       projection_weights, projection_bias,

       output_state, cell_state, input_layer_norm_coefficients, forget_layer_norm_coefficients,
       cell_layer_norm_coefficients, output_layer_norm_coefficients},
      {output, scratchpad_1, scratchpad_2, scratchpad_3}, params)
{
  // Do nothing
}
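
// The initializer list above fixes the kernel's tensor ordering: 24 inputs
// (the data tensor, four input-to-gate and four recurrent-to-gate weight
// matrices, three peephole weight vectors, four gate biases, projection
// weights and bias, the two state tensors, and four layer-norm coefficient
// vectors) followed by 4 outputs (the result plus three scratchpads that back
// output_state, cell_state, and the gate scratch buffer used by EvalFloat).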

// Check that the input tensor dimensions match each other.
void UnidirectionalSequenceLSTM::check_input_tensor_dimensions(int n_input, int n_output,
                                                               int n_cell, bool use_layer_norm,
                                                               bool is_integer)
{
  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  // > 0 means clipping
  LUCI_INTERPRETER_CHECK(params().cell_clip >= 0);
  LUCI_INTERPRETER_CHECK(params().proj_clip >= 0);

  if (input_to_input_weights() != nullptr)
  {
    const Shape &input_to_input_weights_shape = input_to_input_weights()->shape();
    LUCI_INTERPRETER_CHECK(input_to_input_weights_shape.num_dims() == 2);
    LUCI_INTERPRETER_CHECK(input_to_input_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(input_to_input_weights_shape.dim(1) == n_input);
  }

  const Shape &input_to_forget_weights_shape = input_to_forget_weights()->shape();
  LUCI_INTERPRETER_CHECK(input_to_forget_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(input_to_forget_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(input_to_forget_weights_shape.dim(1) == n_input);

  const Shape &input_to_cell_weights_shape = input_to_cell_weights()->shape();
  LUCI_INTERPRETER_CHECK(input_to_cell_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(input_to_cell_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(input_to_cell_weights_shape.dim(1) == n_input);

  if (recurrent_to_input_weights() != nullptr)
  {
    const Shape &recurrent_to_input_weights_shape = recurrent_to_input_weights()->shape();
    LUCI_INTERPRETER_CHECK(recurrent_to_input_weights_shape.num_dims() == 2);
    LUCI_INTERPRETER_CHECK(recurrent_to_input_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(recurrent_to_input_weights_shape.dim(1) == n_output);
  }

  const Shape &recurrent_to_forget_weights_shape = recurrent_to_forget_weights()->shape();
  LUCI_INTERPRETER_CHECK(recurrent_to_forget_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(recurrent_to_forget_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(recurrent_to_forget_weights_shape.dim(1) == n_output);

  const Shape &recurrent_to_cell_weights_shape = recurrent_to_cell_weights()->shape();
  LUCI_INTERPRETER_CHECK(recurrent_to_cell_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(recurrent_to_cell_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(recurrent_to_cell_weights_shape.dim(1) == n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
    ((input_to_input_weights() != nullptr) && (recurrent_to_input_weights() != nullptr)) ||
    ((input_to_input_weights() == nullptr) && (recurrent_to_input_weights() == nullptr));
  LUCI_INTERPRETER_CHECK(cifg_weights_all_or_none == true);

  if (cell_to_input_weights() != nullptr)
  {
    const Shape &cell_to_input_weights_shape = cell_to_input_weights()->shape();
    LUCI_INTERPRETER_CHECK(cell_to_input_weights_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_to_input_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(is_integer
                             ? cell_to_input_weights()->element_type() == loco::DataType::S16
                             : cell_to_input_weights()->element_type() ==
                                 input_to_forget_weights()->element_type());
  }

  if (cell_to_forget_weights() != nullptr)
  {
    const Shape &cell_to_forget_weights_shape = cell_to_forget_weights()->shape();
    LUCI_INTERPRETER_CHECK(cell_to_forget_weights_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_to_forget_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(is_integer
                             ? cell_to_forget_weights()->element_type() == loco::DataType::S16
                             : cell_to_forget_weights()->element_type() ==
                                 input_to_forget_weights()->element_type());
  }

  if (cell_to_output_weights() != nullptr)
  {
    const Shape &cell_to_output_weights_shape = cell_to_output_weights()->shape();
    LUCI_INTERPRETER_CHECK(cell_to_output_weights_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_to_output_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(is_integer
                             ? cell_to_output_weights()->element_type() == loco::DataType::S16
                             : cell_to_output_weights()->element_type() ==
                                 input_to_forget_weights()->element_type());
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights() == nullptr);
  const bool peephole_weights_all_or_none =
    ((cell_to_input_weights() != nullptr || use_cifg) && (cell_to_forget_weights() != nullptr) &&
     (cell_to_output_weights() != nullptr)) ||
    ((cell_to_input_weights() == nullptr) && (cell_to_forget_weights() == nullptr) &&
     (cell_to_output_weights() == nullptr));
  LUCI_INTERPRETER_CHECK(peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  if (use_cifg)
  {
    LUCI_INTERPRETER_CHECK(input_gate_bias() == nullptr);
  }
  else
  {
    const Shape &input_gate_bias_shape = input_gate_bias()->shape();
    LUCI_INTERPRETER_CHECK(input_gate_bias_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(input_gate_bias_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(input_gate_bias()->element_type() == loco::DataType::S32);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(input_gate_bias()->element_type() == loco::DataType::FLOAT32);
    }
  }

  const Shape &forget_gate_bias_shape = forget_gate_bias()->shape();
  LUCI_INTERPRETER_CHECK(forget_gate_bias_shape.num_dims() == 1);
  LUCI_INTERPRETER_CHECK(forget_gate_bias_shape.dim(0) == n_cell);
  if (is_integer)
  {
    LUCI_INTERPRETER_CHECK(forget_gate_bias()->element_type() == loco::DataType::S32);
  }
  else
  {
    LUCI_INTERPRETER_CHECK(forget_gate_bias()->element_type() == loco::DataType::FLOAT32);
  }

  const Shape &cell_gate_bias_shape = cell_gate_bias()->shape();
  LUCI_INTERPRETER_CHECK(cell_gate_bias_shape.num_dims() == 1);
  LUCI_INTERPRETER_CHECK(cell_gate_bias_shape.dim(0) == n_cell);
  if (is_integer)
  {
    LUCI_INTERPRETER_CHECK(cell_gate_bias()->element_type() == loco::DataType::S32);
  }
  else
  {
    LUCI_INTERPRETER_CHECK(cell_gate_bias()->element_type() == loco::DataType::FLOAT32);
  }

  const Shape &output_gate_bias_shape = output_gate_bias()->shape();
  LUCI_INTERPRETER_CHECK(output_gate_bias_shape.num_dims() == 1);
  LUCI_INTERPRETER_CHECK(output_gate_bias_shape.dim(0) == n_cell);
  if (is_integer)
  {
    LUCI_INTERPRETER_CHECK(output_gate_bias()->element_type() == loco::DataType::S32);
  }
  else
  {
    LUCI_INTERPRETER_CHECK(output_gate_bias()->element_type() == loco::DataType::FLOAT32);
  }

  if (projection_weights() != nullptr)
  {
    const Shape &projection_weights_shape = projection_weights()->shape();
    LUCI_INTERPRETER_CHECK(projection_weights_shape.num_dims() == 2);
    LUCI_INTERPRETER_CHECK(projection_weights_shape.dim(0) == n_output);
    LUCI_INTERPRETER_CHECK(projection_weights_shape.dim(1) == n_cell);
  }

  if (projection_bias() != nullptr)
  {
    const Shape &projection_bias_shape = projection_bias()->shape();
    LUCI_INTERPRETER_CHECK(projection_bias_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(projection_bias_shape.dim(0) == n_output);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(projection_bias()->element_type() == loco::DataType::S32);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(projection_bias()->element_type() == loco::DataType::FLOAT32);
    }
  }
674 // Making sure the projection tensors are consistent:
675 // 1) If projection weight is not present, then projection bias should not be
676 // present.
677 // 2) If projection weight is present, then projection bias is optional.
678 // TODO(ghodrat): make sure this is correct.
679 const bool projecton_tensors_consistent =
680 ((projection_weights() != nullptr) || (projection_bias() == nullptr));
681 LUCI_INTERPRETER_CHECK(projecton_tensors_consistent == true);

  if (use_layer_norm)
  {
    if (use_cifg)
    {
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients() == nullptr);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients() != nullptr);

      const Shape &input_layer_norm_coefficients_shape = input_layer_norm_coefficients()->shape();
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients_shape.num_dims() == 1);
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients_shape.dim(0) == n_cell);
      if (is_integer)
      {
        LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients()->element_type() ==
                               loco::DataType::S16);
      }
      else
      {
        LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients()->element_type() ==
                               loco::DataType::FLOAT32);
      }
    }

    const Shape &forget_layer_norm_coefficients_shape = forget_layer_norm_coefficients()->shape();
    LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients()->element_type() ==
                             loco::DataType::S16);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients()->element_type() ==
                             loco::DataType::FLOAT32);
    }

    const Shape &cell_layer_norm_coefficients_shape = cell_layer_norm_coefficients()->shape();
    LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients()->element_type() == loco::DataType::S16);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients()->element_type() ==
                             loco::DataType::FLOAT32);
    }

    const Shape &output_layer_norm_coefficients_shape = output_layer_norm_coefficients()->shape();
    LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients()->element_type() ==
                             loco::DataType::S16);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients()->element_type() ==
                             loco::DataType::FLOAT32);
    }
  }
}
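
// Shape summary enforced above (for quick reference):
//   input_to_{input,forget,cell,output}_weights      : [n_cell, n_input]
//   recurrent_to_{input,forget,cell,output}_weights  : [n_cell, n_output]
//   cell_to_{input,forget,output}_weights (peephole) : [n_cell]
//   {input,forget,cell,output}_gate_bias             : [n_cell]
//   projection_weights                               : [n_output, n_cell]
//   projection_bias                                  : [n_output]
//   *_layer_norm_coefficients                        : [n_cell]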

void UnidirectionalSequenceLSTM::configure()
{
  LUCI_INTERPRETER_CHECK(getInputTensors().size() == 24);
  LUCI_INTERPRETER_CHECK(getOutputTensors().size() == 4);

  // TODO support U8
  LUCI_INTERPRETER_CHECK(input()->element_type() == loco::DataType::FLOAT32);
  const bool is_integer = false;
  const bool use_layer_norm = (forget_layer_norm_coefficients() != nullptr);

  // Infer the batch size, number of outputs, sequence length, and number of
  // cells from the input tensors.
  const Shape &input_shape = input()->shape();
  LUCI_INTERPRETER_CHECK(input_shape.num_dims() > 1);
  const bool time_major = params().time_major;
  const int n_batch = time_major ? input_shape.dim(1) : input_shape.dim(0);
  // NOTE dim(2) is accessed below, so check that it exists first
  LUCI_INTERPRETER_CHECK(input_shape.num_dims() > 2);
  const int n_input = input_shape.dim(2);

  const Shape &input_to_output_weights_shape = input_to_output_weights()->shape();
  const int n_cell = input_to_output_weights_shape.dim(0);
  LUCI_INTERPRETER_CHECK(input_to_output_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(input_to_output_weights_shape.dim(1) == n_input);

  const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights()->shape();
  LUCI_INTERPRETER_CHECK(recurrent_to_output_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(recurrent_to_output_weights_shape.dim(0) == n_cell);

  const int n_output = recurrent_to_output_weights_shape.dim(1);

  // Check that the input tensor dimensions match each other.
  check_input_tensor_dimensions(n_input, n_output, n_cell, use_layer_norm, is_integer);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  const Shape &output_state_shape = output_state()->shape();
  const Shape &cell_state_shape = cell_state()->shape();
  LUCI_INTERPRETER_CHECK(output_state_shape.num_elements() == n_batch * n_output);
  LUCI_INTERPRETER_CHECK(cell_state_shape.num_elements() == n_batch * n_cell);

  // Resize the output tensors.
  Shape output_shape = Shape(input_shape.num_dims());
  for (int i = 0; i < input_shape.num_dims() - 1; i++)
  {
    output_shape.dim(i) = input_shape.dim(i);
  }
  output_shape.dim(input_shape.num_dims() - 1) = n_output;
  output()->resize(output_shape);

  // TODO import integer

  // output_state and cell_state are variable tensors; use scratchpads.
  getOutputTensors()[1]->resize(output_state_shape);
  getOutputTensors()[2]->resize(cell_state_shape);

  const bool use_cifg = (input_to_input_weights() == nullptr);
  if (use_cifg)
    getOutputTensors()[3]->resize({n_batch, n_cell * 3});
  else
    getOutputTensors()[3]->resize({n_batch, n_cell * 4});

  // hybrid not supported
  if (input_to_output_weights()->element_type() == loco::DataType::U8 &&
      input()->element_type() == loco::DataType::FLOAT32)
  {
    throw std::runtime_error("Hybrid type is not currently supported");
  }
  // TODO support hybrid
  // TODO support U8
}
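
// Scratch sizing note (illustrative helper, not part of the upstream kernel):
// the fourth output backs one [n_batch, n_cell] plane per gate, and CIFG folds
// the input gate into the forget gate, so it needs 3 planes instead of 4:
inline int lstm_scratch_planes(bool use_cifg) { return use_cifg ? 3 : 4; }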

void UnidirectionalSequenceLSTM::execute() const
{
  switch (input()->element_type())
  {
    case loco::DataType::FLOAT32:
      evalFloat();
      break;
    default:
      throw std::runtime_error("Unsupported type");
  }
}

void UnidirectionalSequenceLSTM::evalFloat() const
{
  const bool time_major = params().time_major;
  const bool use_layer_norm = (forget_layer_norm_coefficients() != nullptr);

  const Tensor *t_input_layer_norm_coefficients =
    use_layer_norm ? input_layer_norm_coefficients() : nullptr;
  const Tensor *t_forget_layer_norm_coefficients =
    use_layer_norm ? forget_layer_norm_coefficients() : nullptr;
  const Tensor *t_cell_layer_norm_coefficients =
    use_layer_norm ? cell_layer_norm_coefficients() : nullptr;
  const Tensor *t_output_layer_norm_coefficients =
    use_layer_norm ? output_layer_norm_coefficients() : nullptr;

  Tensor *sp_output_state = getOutputTensors()[1];
  Tensor *sp_cell_state = getOutputTensors()[2];
  Tensor *sp_scratch_buffer = getOutputTensors()[3];

  // NOTE the output_state input variable tensor is expected to be reset to
  // zero, and it is also expected not to own a buffer, so the scratchpad
  // tensors are zero-filled here instead.
  auto scratchpad_data = getTensorData<float>(sp_output_state);
  std::fill_n(scratchpad_data, sp_output_state->shape().num_elements(), 0);
  scratchpad_data = getTensorData<float>(sp_cell_state);
  std::fill_n(scratchpad_data, sp_cell_state->shape().num_elements(), 0);
  scratchpad_data = getTensorData<float>(sp_scratch_buffer);
  std::fill_n(scratchpad_data, sp_scratch_buffer->shape().num_elements(), 0);

  TfLiteLSTMParams lstm_params{};
  lstm_params.activation = getTfLiteActivation(params().activation);
  lstm_params.cell_clip = params().cell_clip;
  lstm_params.proj_clip = params().proj_clip;
  lstm_params.asymmetric_quantize_inputs = params().asymmetric_quantize_inputs;

  lstm::EvalFloat(input(), input_to_input_weights(), input_to_forget_weights(),
                  input_to_cell_weights(), input_to_output_weights(),

                  recurrent_to_input_weights(), recurrent_to_forget_weights(),
                  recurrent_to_cell_weights(), recurrent_to_output_weights(),

                  cell_to_input_weights(), cell_to_forget_weights(), cell_to_output_weights(),

                  t_input_layer_norm_coefficients, t_forget_layer_norm_coefficients,
                  t_cell_layer_norm_coefficients, t_output_layer_norm_coefficients,
                  /*aux_input=*/nullptr,
                  /*aux_input_to_input_weights=*/nullptr,
                  /*aux_input_to_forget_weights=*/nullptr,
                  /*aux_input_to_cell_weights=*/nullptr,
                  /*aux_input_to_output_weights=*/nullptr, input_gate_bias(), forget_gate_bias(),
                  cell_gate_bias(), output_gate_bias(),

                  projection_weights(), projection_bias(), &lstm_params,
                  /*forward_sequence=*/true, time_major,
                  /*output_offset=*/0, sp_scratch_buffer, sp_output_state, sp_cell_state, output());
}

} // namespace kernels
} // namespace luci_interpreter