ONE - On-device Neural Engine
UnidirectionalSequenceLSTM.cpp
/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/UnidirectionalSequenceLSTM.h"
#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/tensor_utils.h>

#include <algorithm> // std::copy_n, std::fill_n
#include <cassert>

namespace luci_interpreter
{
namespace kernels
{
namespace lstm
{
namespace
{

using namespace tflite;
void UpdateLstmCellFloat(int n_batch, int n_cell, float *cell_state, const float *input_gate,
                         float *forget_gate, const float *cell_gate, bool use_cifg, float clip)
{
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state, n_batch * n_cell, cell_state);

  if (use_cifg)
  {
    // With CIFG, input_gate = 1 - forget_gate. Use the forget_gate array as
    // scratch, as the input_gate array is not allocated in this case. (Be
    // careful not to write to the scratch before reading the forget gate data.)
    float *scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, scratch, n_batch * n_cell,
                                                     cell_state);
  }
  else
  {
    tensor_utils::VectorVectorCwiseProductAccumulate(cell_gate, input_gate, n_batch * n_cell,
                                                     cell_state);
  }
  if (clip > 0.0f)
  {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
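
// NOTE Elementwise, per batch b and cell j, the function above computes
//   cell_state[b][j] = forget_gate[b][j] * cell_state[b][j]
//                      + input_gate[b][j] * cell_gate[b][j]
// where CIFG replaces input_gate with (1 - forget_gate). For example, with
// forget_gate = 0.9, cell_state = 1.0 and cell_gate = 0.5, the CIFG path
// yields 0.9 * 1.0 + (1 - 0.9) * 0.5 = 0.95 before clipping.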

void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, const float *cell_state,
                              const float *output_gate, TfLiteFusedActivation activation,
                              const float *projection_weights, const float *projection_bias,
                              const float proj_clip, float *output_state, float *scratch)
{
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell, scratch);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection)
  {
    if (use_projection_bias)
    {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch, output_state);
    }
    else
    {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(projection_weights, n_output, n_cell, scratch,
                                                      n_batch, output_state);
    if (proj_clip > 0.0f)
    {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  }
  else
  {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}
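
// NOTE In equation form the function above computes h = output_gate * act(cell_state)
// elementwise, and with projection output_state = clip(projection_weights . h
// + projection_bias, proj_clip). Without projection, n_output is expected to
// equal n_cell and h is copied to output_state directly.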

inline void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights,
                                   const float *aux_input, const float *aux_input_to_gate_weights,
                                   const float *output_state,
                                   const float *recurrent_to_gate_weights, const float *cell_state,
                                   const float *cell_to_gate_weights,
                                   const float *layer_norm_coefficients, const float *gate_bias,
                                   const int n_batch, const int n_input, const int n_aux_input,
                                   const int n_output, const int n_cell,
                                   const TfLiteFusedActivation activation, float *gate,
                                   const bool is_input_all_zeros, const bool is_aux_input_all_zeros)
{
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize the scratch buffer with the bias for a regular LSTM, or with
  // zero for a layer norm LSTM.
  if (use_layer_norm)
  {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  }
  else
  {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if the input is all zeros.
  if (!is_input_all_zeros)
  {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, n_cell, n_input, input,
                                                      n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if the auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros)
  {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, n_cell,
                                                      n_aux_input, aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, n_cell, n_output,
                                                    output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM).
  if (use_peephole)
  {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_gate_weights, n_cell, cell_state,
                                                          n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM).
  if (use_layer_norm)
  {
    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, gate, n_batch,
                                                gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply the activation.
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation, gate);
}
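
// NOTE Putting the terms together, for each gate the function above computes
//   gate = act(W_input * input + W_aux * aux_input + W_recurrent * output_state
//              + w_peephole .* cell_state + bias)
// where layer normalization, when enabled, is applied to the accumulated sum
// before the coefficients and the bias are folded in.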

inline void LstmStepFloat(
  const float *input_ptr, const float *input_to_input_weights_ptr,
  const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr,
  const float *input_to_output_weights_ptr, const float *aux_input_ptr,
  const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr,
  const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr,
  const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr,
  const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr,
  const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr,
  const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr,
  const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr,
  const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr,
  const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr,
  const float *output_gate_bias_ptr, const float *projection_weights_ptr,
  const float *projection_bias_ptr, const TfLiteLSTMParams *params, int n_batch, int n_cell,
  int n_input, int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr,
  float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3,
  float *output_ptr)
{
  // Since we have already checked that the weights are all there or none, we
  // can check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float *input_gate_scratch = scratch0;
  float *forget_gate_scratch = scratch1;
  float *cell_gate_scratch = scratch2;
  float *output_gate_scratch = scratch3;

  // Check if the inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros = tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
    (aux_input_ptr == nullptr || tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg)
  {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
                           aux_input_to_input_weights_ptr, output_state_ptr,
                           recurrent_to_input_weights_ptr, cell_state_ptr,
                           cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
                           input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                           /*activation=*/kTfLiteActSigmoid, input_gate_scratch, is_input_all_zeros,
                           is_aux_input_all_zeros);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
                         aux_input_to_forget_weights_ptr, output_state_ptr,
                         recurrent_to_forget_weights_ptr, cell_state_ptr,
                         cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
                         forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                         /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
                         is_aux_input_all_zeros);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(
    input_ptr, input_to_cell_weights_ptr, aux_input_ptr, aux_input_to_cell_weights_ptr,
    output_state_ptr, recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
    /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, n_batch,
    n_input, n_aux_input, n_output, n_cell, params->activation, cell_gate_scratch,
    is_input_all_zeros, is_aux_input_all_zeros);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
                      cell_gate_scratch, use_cifg, params->cell_clip);
  // Calculate the output gate.
  CalculateLstmGateFloat(input_ptr, input_to_output_weights_ptr, aux_input_ptr,
                         aux_input_to_output_weights_ptr, output_state_ptr,
                         recurrent_to_output_weights_ptr, cell_state_ptr,
                         cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
                         output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
                         /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
                         is_aux_input_all_zeros);
  // Update the output state. (cell_gate_scratch, i.e. scratch2, is no longer
  // needed at this point, so it is reused as the activation scratch here.)
  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
                           params->activation, projection_weights_ptr, projection_bias_ptr,
                           params->proj_clip, output_state_ptr, scratch2);
  // Copy the output state to the output. Note that the output's rows may not
  // be contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++)
  {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}
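
// NOTE One full step, in order: the input (unless CIFG), forget and cell gates
// are computed, the cell state is updated, and only then is the output gate
// computed, so its peephole term sees the updated cell state. Finally the
// (optionally projected) output state is copied out row by row to honor
// output_batch_leading_dim.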

} // namespace

void EvalFloat(const Tensor *input,

               const Tensor *input_to_input_weights, const Tensor *input_to_forget_weights,
               const Tensor *input_to_cell_weights, const Tensor *input_to_output_weights,

               const Tensor *recurrent_to_input_weights, const Tensor *recurrent_to_forget_weights,
               const Tensor *recurrent_to_cell_weights, const Tensor *recurrent_to_output_weights,

               const Tensor *cell_to_input_weights, const Tensor *cell_to_forget_weights,
               const Tensor *cell_to_output_weights,

               const Tensor *input_layer_norm_coefficients,
               const Tensor *forget_layer_norm_coefficients,
               const Tensor *cell_layer_norm_coefficients,
               const Tensor *output_layer_norm_coefficients,

               const Tensor *aux_input, const Tensor *aux_input_to_input_weights,
               const Tensor *aux_input_to_forget_weights, const Tensor *aux_input_to_cell_weights,
               const Tensor *aux_input_to_output_weights,

               const Tensor *input_gate_bias, const Tensor *forget_gate_bias,
               const Tensor *cell_gate_bias, const Tensor *output_gate_bias,

               const Tensor *projection_weights, const Tensor *projection_bias,
               const TfLiteLSTMParams *params,

               bool forward_sequence, bool time_major, int output_offset,

               Tensor *scratch_buffer, Tensor *output_state, Tensor *cell_state, Tensor *output)
{
  const Shape &input_shape = input->shape();
  assert(input_shape.num_dims() >= 2 && input_shape.num_dims() <= 3);
  int max_time, n_batch;
  if (input_shape.num_dims() == 3)
  {
    max_time = (time_major) ? input_shape.dim(0) : input_shape.dim(1);
    n_batch = (time_major) ? input_shape.dim(1) : input_shape.dim(0);
  }
  else
  {
    max_time = 1;
    n_batch = input_shape.dim(0);
  }
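  // NOTE A 3D input is laid out as [max_time, n_batch, n_input] when
  // time_major and [n_batch, max_time, n_input] otherwise; e.g. a time-major
  // {5, 2, 8} tensor holds 5 steps of a batch of 2 with 8 features each.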
  const int n_input = input_shape.dim(input_shape.num_dims() - 1);

  int aux_input_temp = 0;
  if (aux_input)
  {
    const Shape &aux_input_shape = aux_input->shape();
    aux_input_temp = aux_input_shape.dim(aux_input_shape.num_dims() - 1);
  }
  const int aux_input_size = aux_input_temp;

  // n_cell and n_output will be the same size when there is no projection.
  const Shape &input_to_output_weights_shape = input_to_output_weights->shape();
  const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights->shape();
  const int n_cell = input_to_output_weights_shape.dim(0);
  const int n_output = recurrent_to_output_weights_shape.dim(1);

  // Since we have already checked that the weights are all there or none, we
  // can check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the scratch buffer pointers into the global scratch buffer.
  float *scratch_buffer_ptr = getTensorData<float>(scratch_buffer);
  float *input_gate_scratch = nullptr;
  float *cell_gate_scratch = nullptr;
  float *forget_gate_scratch = nullptr;
  float *output_gate_scratch = nullptr;
  if (use_cifg)
  {
    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  }
  else
  {
    input_gate_scratch = scratch_buffer_ptr;
    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }
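  // NOTE Resulting layout of the flat scratch buffer, in n_cell * n_batch
  // chunks:
  //   CIFG:     [cell | forget | output]
  //   non-CIFG: [input | cell | forget | output]
  // which matches the {n_batch, n_cell * 3} / {n_batch, n_cell * 4} sizes
  // requested in UnidirectionalSequenceLSTM::configure() below.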

  const Shape &output_shape = output->shape();
  const int output_batch_leading_dim = output_shape.dim(output_shape.num_dims() - 1);
  if (time_major)
  {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++)
    {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float *input_ptr = getTensorData<float>(input) + t_rel * input_step;
      const float *aux_input_ptr = nullptr;
      if (aux_input)
      {
        aux_input_ptr = getTensorData<float>(aux_input) + t_rel * input_step;
      }
      float *output_ptr = getTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepFloat(
        input_ptr, getTensorData<float>(input_to_input_weights),
        getTensorData<float>(input_to_forget_weights), getTensorData<float>(input_to_cell_weights),
        getTensorData<float>(input_to_output_weights), aux_input_ptr,
        getTensorData<float>(aux_input_to_input_weights),
        getTensorData<float>(aux_input_to_forget_weights),
        getTensorData<float>(aux_input_to_cell_weights),
        getTensorData<float>(aux_input_to_output_weights),
        getTensorData<float>(recurrent_to_input_weights),
        getTensorData<float>(recurrent_to_forget_weights),
        getTensorData<float>(recurrent_to_cell_weights),
        getTensorData<float>(recurrent_to_output_weights),
        getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
        getTensorData<float>(cell_to_output_weights),
        getTensorData<float>(input_layer_norm_coefficients),
        getTensorData<float>(forget_layer_norm_coefficients),
        getTensorData<float>(cell_layer_norm_coefficients),
        getTensorData<float>(output_layer_norm_coefficients), getTensorData<float>(input_gate_bias),
        getTensorData<float>(forget_gate_bias), getTensorData<float>(cell_gate_bias),
        getTensorData<float>(output_gate_bias), getTensorData<float>(projection_weights),
        getTensorData<float>(projection_bias), params, n_batch, n_cell, n_input, aux_input_size,
        n_output, output_batch_leading_dim, getTensorData<float>(output_state),
        getTensorData<float>(cell_state), input_gate_scratch, forget_gate_scratch,
        cell_gate_scratch, output_gate_scratch, output_ptr);
    }
  }
  else
  {
    for (int b = 0; b < n_batch; b++)
    {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++)
      {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float *input_ptr = getTensorData<float>(input) + time_offset * input_step;
        const float *aux_input_ptr = nullptr;
        if (aux_input)
        {
          aux_input_ptr = getTensorData<float>(aux_input) + time_offset * input_step;
        }
        float *output_ptr =
          getTensorData<float>(output) + time_offset * output_step + output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float *output_state_ptr = getTensorData<float>(output_state) + b * output_batch_leading_dim;
        float *cell_state_ptr = getTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float *input_gate_scratch_ptr =
          input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float *forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float *cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float *output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepFloat(
          input_ptr, getTensorData<float>(input_to_input_weights),
          getTensorData<float>(input_to_forget_weights),
          getTensorData<float>(input_to_cell_weights),
          getTensorData<float>(input_to_output_weights), aux_input_ptr,
          getTensorData<float>(aux_input_to_input_weights),
          getTensorData<float>(aux_input_to_forget_weights),
          getTensorData<float>(aux_input_to_cell_weights),
          getTensorData<float>(aux_input_to_output_weights),
          getTensorData<float>(recurrent_to_input_weights),
          getTensorData<float>(recurrent_to_forget_weights),
          getTensorData<float>(recurrent_to_cell_weights),
          getTensorData<float>(recurrent_to_output_weights),
          getTensorData<float>(cell_to_input_weights), getTensorData<float>(cell_to_forget_weights),
          getTensorData<float>(cell_to_output_weights),
          getTensorData<float>(input_layer_norm_coefficients),
          getTensorData<float>(forget_layer_norm_coefficients),
          getTensorData<float>(cell_layer_norm_coefficients),
          getTensorData<float>(output_layer_norm_coefficients),
          getTensorData<float>(input_gate_bias), getTensorData<float>(forget_gate_bias),
          getTensorData<float>(cell_gate_bias), getTensorData<float>(output_gate_bias),
          getTensorData<float>(projection_weights), getTensorData<float>(projection_bias), params,
          /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
          output_state_ptr, cell_state_ptr, input_gate_scratch_ptr, forget_gate_scratch_ptr,
          cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
      }
    }
  }
}

} // namespace lstm
} // namespace kernels
} // namespace luci_interpreter

namespace luci_interpreter
{
namespace kernels
{

UnidirectionalSequenceLSTM::UnidirectionalSequenceLSTM(
  const Tensor *input,

  const Tensor *input_to_input_weights, const Tensor *input_to_forget_weights,
  const Tensor *input_to_cell_weights, const Tensor *input_to_output_weights,

  const Tensor *recurrent_to_input_weights, const Tensor *recurrent_to_forget_weights,
  const Tensor *recurrent_to_cell_weights, const Tensor *recurrent_to_output_weights,

  const Tensor *cell_to_input_weights, const Tensor *cell_to_forget_weights,
  const Tensor *cell_to_output_weights,

  const Tensor *input_gate_bias, const Tensor *forget_gate_bias, const Tensor *cell_gate_bias,
  const Tensor *output_gate_bias,

  const Tensor *projection_weights, const Tensor *projection_bias,

  const Tensor *output_state, const Tensor *cell_state,
  const Tensor *input_layer_norm_coefficients, const Tensor *forget_layer_norm_coefficients,
  const Tensor *cell_layer_norm_coefficients, const Tensor *output_layer_norm_coefficients,

  Tensor *output, Tensor *scratchpad_1, Tensor *scratchpad_2, Tensor *scratchpad_3,
  const UnidirectionalSequenceLSTMParams &params)
  : KernelWithParams<UnidirectionalSequenceLSTMParams>(
      {input,

       input_to_input_weights, input_to_forget_weights, input_to_cell_weights,
       input_to_output_weights,

       recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
       recurrent_to_output_weights,

       cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,

       input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,

       projection_weights, projection_bias,

       output_state, cell_state, input_layer_norm_coefficients, forget_layer_norm_coefficients,
       cell_layer_norm_coefficients, output_layer_norm_coefficients},
      {output, scratchpad_1, scratchpad_2, scratchpad_3}, params)
{
  // Do nothing
}

// Check that the input tensor dimensions match each other.
void UnidirectionalSequenceLSTM::check_input_tensor_dimensions(int n_input, int n_output,
                                                               int n_cell, bool use_layer_norm,
                                                               bool is_integer)
{
  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  LUCI_INTERPRETER_CHECK(params().cell_clip >= 0);
  LUCI_INTERPRETER_CHECK(params().proj_clip >= 0);

  if (input_to_input_weights() != nullptr)
  {
    const Shape &input_to_input_weights_shape = input_to_input_weights()->shape();
    LUCI_INTERPRETER_CHECK(input_to_input_weights_shape.num_dims() == 2);
    LUCI_INTERPRETER_CHECK(input_to_input_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(input_to_input_weights_shape.dim(1) == n_input);
  }

  const Shape &input_to_forget_weights_shape = input_to_forget_weights()->shape();
  LUCI_INTERPRETER_CHECK(input_to_forget_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(input_to_forget_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(input_to_forget_weights_shape.dim(1) == n_input);

  const Shape &input_to_cell_weights_shape = input_to_cell_weights()->shape();
  LUCI_INTERPRETER_CHECK(input_to_cell_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(input_to_cell_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(input_to_cell_weights_shape.dim(1) == n_input);

  if (recurrent_to_input_weights() != nullptr)
  {
    const Shape &recurrent_to_input_weights_shape = recurrent_to_input_weights()->shape();
    LUCI_INTERPRETER_CHECK(recurrent_to_input_weights_shape.num_dims() == 2);
    LUCI_INTERPRETER_CHECK(recurrent_to_input_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(recurrent_to_input_weights_shape.dim(1) == n_output);
  }

  const Shape &recurrent_to_forget_weights_shape = recurrent_to_forget_weights()->shape();
  LUCI_INTERPRETER_CHECK(recurrent_to_forget_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(recurrent_to_forget_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(recurrent_to_forget_weights_shape.dim(1) == n_output);

  const Shape &recurrent_to_cell_weights_shape = recurrent_to_cell_weights()->shape();
  LUCI_INTERPRETER_CHECK(recurrent_to_cell_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(recurrent_to_cell_weights_shape.dim(0) == n_cell);
  LUCI_INTERPRETER_CHECK(recurrent_to_cell_weights_shape.dim(1) == n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
    ((input_to_input_weights() != nullptr) && (recurrent_to_input_weights() != nullptr)) ||
    ((input_to_input_weights() == nullptr) && (recurrent_to_input_weights() == nullptr));
  LUCI_INTERPRETER_CHECK(cifg_weights_all_or_none == true);

  if (cell_to_input_weights() != nullptr)
  {
    const Shape &cell_to_input_weights_shape = cell_to_input_weights()->shape();
    LUCI_INTERPRETER_CHECK(cell_to_input_weights_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_to_input_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(is_integer
                             ? cell_to_input_weights()->element_type() == loco::DataType::S16
                             : cell_to_input_weights()->element_type() ==
                                 input_to_forget_weights()->element_type());
  }

  if (cell_to_forget_weights() != nullptr)
  {
    const Shape &cell_to_forget_weights_shape = cell_to_forget_weights()->shape();
    LUCI_INTERPRETER_CHECK(cell_to_forget_weights_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_to_forget_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(is_integer
                             ? cell_to_forget_weights()->element_type() == loco::DataType::S16
                             : cell_to_forget_weights()->element_type() ==
                                 input_to_forget_weights()->element_type());
  }

  if (cell_to_output_weights() != nullptr)
  {
    const Shape &cell_to_output_weights_shape = cell_to_output_weights()->shape();
    LUCI_INTERPRETER_CHECK(cell_to_output_weights_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_to_output_weights_shape.dim(0) == n_cell);
    LUCI_INTERPRETER_CHECK(is_integer
                             ? cell_to_output_weights()->element_type() == loco::DataType::S16
                             : cell_to_output_weights()->element_type() ==
                                 input_to_forget_weights()->element_type());
  }

  // Making sure the peephole weights are there all or none.
  const bool use_cifg = (input_to_input_weights() == nullptr);
  const bool peephole_weights_all_or_none =
    ((cell_to_input_weights() != nullptr || use_cifg) && (cell_to_forget_weights() != nullptr) &&
     (cell_to_output_weights() != nullptr)) ||
    ((cell_to_input_weights() == nullptr) && (cell_to_forget_weights() == nullptr) &&
     (cell_to_output_weights() == nullptr));
  LUCI_INTERPRETER_CHECK(peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  if (use_cifg)
  {
    LUCI_INTERPRETER_CHECK(input_gate_bias() == nullptr);
  }
  else
  {
    const Shape &input_gate_bias_shape = input_gate_bias()->shape();
    LUCI_INTERPRETER_CHECK(input_gate_bias_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(input_gate_bias_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(input_gate_bias()->element_type() == loco::DataType::S32);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(input_gate_bias()->element_type() == loco::DataType::FLOAT32);
    }
  }

  const Shape &forget_gate_bias_shape = forget_gate_bias()->shape();
  LUCI_INTERPRETER_CHECK(forget_gate_bias_shape.num_dims() == 1);
  LUCI_INTERPRETER_CHECK(forget_gate_bias_shape.dim(0) == n_cell);
  if (is_integer)
  {
    LUCI_INTERPRETER_CHECK(forget_gate_bias()->element_type() == loco::DataType::S32);
  }
  else
  {
    LUCI_INTERPRETER_CHECK(forget_gate_bias()->element_type() == loco::DataType::FLOAT32);
  }

  const Shape &cell_gate_bias_shape = cell_gate_bias()->shape();
  LUCI_INTERPRETER_CHECK(cell_gate_bias_shape.num_dims() == 1);
  LUCI_INTERPRETER_CHECK(cell_gate_bias_shape.dim(0) == n_cell);
  if (is_integer)
  {
    LUCI_INTERPRETER_CHECK(cell_gate_bias()->element_type() == loco::DataType::S32);
  }
  else
  {
    LUCI_INTERPRETER_CHECK(cell_gate_bias()->element_type() == loco::DataType::FLOAT32);
  }

  const Shape &output_gate_bias_shape = output_gate_bias()->shape();
  LUCI_INTERPRETER_CHECK(output_gate_bias_shape.num_dims() == 1);
  LUCI_INTERPRETER_CHECK(output_gate_bias_shape.dim(0) == n_cell);
  if (is_integer)
  {
    LUCI_INTERPRETER_CHECK(output_gate_bias()->element_type() == loco::DataType::S32);
  }
  else
  {
    LUCI_INTERPRETER_CHECK(output_gate_bias()->element_type() == loco::DataType::FLOAT32);
  }

  if (projection_weights() != nullptr)
  {
    const Shape &projection_weights_shape = projection_weights()->shape();
    LUCI_INTERPRETER_CHECK(projection_weights_shape.num_dims() == 2);
    LUCI_INTERPRETER_CHECK(projection_weights_shape.dim(0) == n_output);
    LUCI_INTERPRETER_CHECK(projection_weights_shape.dim(1) == n_cell);
  }

  if (projection_bias() != nullptr)
  {
    const Shape &projection_bias_shape = projection_bias()->shape();
    LUCI_INTERPRETER_CHECK(projection_bias_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(projection_bias_shape.dim(0) == n_output);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(projection_bias()->element_type() == loco::DataType::S32);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(projection_bias()->element_type() == loco::DataType::FLOAT32);
    }
  }

  // Making sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    should not be present.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
    ((projection_weights() != nullptr) || (projection_bias() == nullptr));
  LUCI_INTERPRETER_CHECK(projection_tensors_consistent == true);

  if (use_layer_norm)
  {
    if (use_cifg)
    {
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients() == nullptr);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients() != nullptr);

      const Shape &input_layer_norm_coefficients_shape = input_layer_norm_coefficients()->shape();
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients_shape.num_dims() == 1);
      LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients_shape.dim(0) == n_cell);
      if (is_integer)
      {
        LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients()->element_type() ==
                               loco::DataType::S16);
      }
      else
      {
        LUCI_INTERPRETER_CHECK(input_layer_norm_coefficients()->element_type() ==
                               loco::DataType::FLOAT32);
      }
    }

    const Shape &forget_layer_norm_coefficients_shape = forget_layer_norm_coefficients()->shape();
    LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients()->element_type() ==
                             loco::DataType::S16);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(forget_layer_norm_coefficients()->element_type() ==
                             loco::DataType::FLOAT32);
    }

    const Shape &cell_layer_norm_coefficients_shape = cell_layer_norm_coefficients()->shape();
    LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients()->element_type() == loco::DataType::S16);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(cell_layer_norm_coefficients()->element_type() ==
                             loco::DataType::FLOAT32);
    }

    const Shape &output_layer_norm_coefficients_shape = output_layer_norm_coefficients()->shape();
    LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients_shape.num_dims() == 1);
    LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients_shape.dim(0) == n_cell);
    if (is_integer)
    {
      LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients()->element_type() ==
                             loco::DataType::S16);
    }
    else
    {
      LUCI_INTERPRETER_CHECK(output_layer_norm_coefficients()->element_type() ==
                             loco::DataType::FLOAT32);
    }
  }
}

void UnidirectionalSequenceLSTM::configure()
{
  LUCI_INTERPRETER_CHECK(getInputTensors().size() == 24);
  LUCI_INTERPRETER_CHECK(getOutputTensors().size() == 4);

  // TODO support U8
  LUCI_INTERPRETER_CHECK(input()->element_type() == loco::DataType::FLOAT32);
  const bool is_integer = false;
  const bool use_layer_norm = (forget_layer_norm_coefficients() != nullptr);

  // Infer the batch size, number of outputs, sequence length and number of
  // cells from the input tensors.
  const Shape &input_shape = input()->shape();
  LUCI_INTERPRETER_CHECK(input_shape.num_dims() > 1);
  const bool time_major = params().time_major;
  const int n_batch = time_major ? input_shape.dim(1) : input_shape.dim(0);
  // NOTE dim(2) is accessed below, so the rank has to be validated first.
  LUCI_INTERPRETER_CHECK(input_shape.num_dims() > 2);
  const int n_input = input_shape.dim(2);

  const Shape &input_to_output_weights_shape = input_to_output_weights()->shape();
  const int n_cell = input_to_output_weights_shape.dim(0);
  LUCI_INTERPRETER_CHECK(input_to_output_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(input_to_output_weights_shape.dim(1) == n_input);

  const Shape &recurrent_to_output_weights_shape = recurrent_to_output_weights()->shape();
  LUCI_INTERPRETER_CHECK(recurrent_to_output_weights_shape.num_dims() == 2);
  LUCI_INTERPRETER_CHECK(recurrent_to_output_weights_shape.dim(0) == n_cell);

  const int n_output = recurrent_to_output_weights_shape.dim(1);

  // Check that the input tensor dimensions match each other.
  check_input_tensor_dimensions(n_input, n_output, n_cell, use_layer_norm, is_integer);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  const Shape &output_state_shape = output_state()->shape();
  const Shape &cell_state_shape = cell_state()->shape();
  LUCI_INTERPRETER_CHECK(output_state_shape.num_elements() == n_batch * n_output);
  LUCI_INTERPRETER_CHECK(cell_state_shape.num_elements() == n_batch * n_cell);

  // Resize the output tensors.
  Shape output_shape = Shape(input_shape.num_dims());
  for (int i = 0; i < input_shape.num_dims() - 1; i++)
  {
    output_shape.dim(i) = input_shape.dim(i);
  }
  output_shape.dim(input_shape.num_dims() - 1) = n_output;
  output()->resize(output_shape);

  // TODO import integer

  // output_state and cell_state are variable tensors; use scratchpads.
  getOutputTensors()[1]->resize(output_state_shape);
  getOutputTensors()[2]->resize(cell_state_shape);

  const bool use_cifg = (input_to_input_weights() == nullptr);
  if (use_cifg)
    getOutputTensors()[3]->resize({n_batch, n_cell * 3});
  else
    getOutputTensors()[3]->resize({n_batch, n_cell * 4});

  // hybrid not supported
  if (input_to_output_weights()->element_type() == loco::DataType::U8 &&
      input()->element_type() == loco::DataType::FLOAT32)
  {
    throw std::runtime_error("Hybrid type is not currently supported");
  }
  // TODO support hybrid
  // TODO support U8
}
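
// NOTE As a concrete example of the shape inference above: a batch-major input
// of shape {2, 5, 8} with n_output == 16 resizes output to {2, 5, 16}, the
// state scratchpads to n_batch * n_output and n_batch * n_cell elements, and
// the gate scratchpad to {2, n_cell * 4} (or {2, n_cell * 3} under CIFG).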

void UnidirectionalSequenceLSTM::execute() const
{
  switch (input()->element_type())
  {
    case loco::DataType::FLOAT32:
      evalFloat();
      break;
    default:
      throw std::runtime_error("Unsupported type");
  }
}

void UnidirectionalSequenceLSTM::evalFloat() const
{
  const bool time_major = params().time_major;
  const bool use_layer_norm = (forget_layer_norm_coefficients() != nullptr);

  const Tensor *t_input_layer_norm_coefficients =
    use_layer_norm ? input_layer_norm_coefficients() : nullptr;
  const Tensor *t_forget_layer_norm_coefficients =
    use_layer_norm ? forget_layer_norm_coefficients() : nullptr;
  const Tensor *t_cell_layer_norm_coefficients =
    use_layer_norm ? cell_layer_norm_coefficients() : nullptr;
  const Tensor *t_output_layer_norm_coefficients =
    use_layer_norm ? output_layer_norm_coefficients() : nullptr;

  Tensor *sp_output_state = getOutputTensors()[1];
  Tensor *sp_cell_state = getOutputTensors()[2];
  Tensor *sp_scratch_buffer = getOutputTensors()[3];

  // NOTE the output_state variable tensor is expected to be reset to zero
  // here, and this variable tensor is not expected to own a buffer of its own.
  auto scratchpad_data = getTensorData<float>(sp_output_state);
  std::fill_n(scratchpad_data, sp_output_state->shape().num_elements(), 0);
  scratchpad_data = getTensorData<float>(sp_cell_state);
  std::fill_n(scratchpad_data, sp_cell_state->shape().num_elements(), 0);
  scratchpad_data = getTensorData<float>(sp_scratch_buffer);
  std::fill_n(scratchpad_data, sp_scratch_buffer->shape().num_elements(), 0);

  TfLiteLSTMParams lstm_params{};
  lstm_params.activation = getTfLiteActivation(params().activation);
  lstm_params.cell_clip = params().cell_clip;
  lstm_params.proj_clip = params().proj_clip;
  lstm_params.asymmetric_quantize_inputs = params().asymmetric_quantize_inputs;

  lstm::EvalFloat(input(), input_to_input_weights(), input_to_forget_weights(),
                  input_to_cell_weights(), input_to_output_weights(),

                  recurrent_to_input_weights(), recurrent_to_forget_weights(),
                  recurrent_to_cell_weights(), recurrent_to_output_weights(),

                  cell_to_input_weights(), cell_to_forget_weights(), cell_to_output_weights(),

                  t_input_layer_norm_coefficients, t_forget_layer_norm_coefficients,
                  t_cell_layer_norm_coefficients, t_output_layer_norm_coefficients,
                  /*aux_input=*/nullptr,
                  /*aux_input_to_input_weights=*/nullptr,
                  /*aux_input_to_forget_weights=*/nullptr,
                  /*aux_input_to_cell_weights=*/nullptr,
                  /*aux_input_to_output_weights=*/nullptr, input_gate_bias(), forget_gate_bias(),
                  cell_gate_bias(), output_gate_bias(),

                  projection_weights(), projection_bias(), &lstm_params,
                  /*forward_sequence=*/true, time_major,
                  /*output_offset=*/0, sp_scratch_buffer, sp_output_state, sp_cell_state, output());
}

} // namespace kernels
} // namespace luci_interpreter