ONE - On-device Neural Engine
Loading...
Searching...
No Matches
LSTM.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__
19#define __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__
20
#include "cker/TensorUtils.h"
#include "cker/Types.h"

#include <algorithm>
23
24namespace nnfw
25{
26namespace cker
27{
28
29// LINT.IfChange
30// Calculates a single LSTM gate.
31//
32// Implements the following formula: (* is matrix multiply)
33// gate = activate(W_input * input + W_aux * aux_input +
34// W_peephole * cell + W_recurrent * prev_output + bias)
35// with layer norm:
36// gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
37//
38// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
39//
40// Parameters:
41// Input vectors (to LSTM): | Size: | Optional?
42// input | n_input |
43// aux_input | n_aux_input | y (bidir LSTM)
44// Input vectors (persistent states):
45// output_state | n_output |
46// cell_state | n_cell |
47// 'Constant' inputs:
48// input_to_gate_weights | n_cell * n_input |
49// aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
50// recurrent_to_gate_weights | n_cell * n_output |
51// cell_to_gate_weights | n_cell | y (peephole)
52// gate_bias | n_cell |
53// layer_norm_coefficients | n_cell | y (layer norm)
54// Output vector:
55// gate | n_cell |
56// Scalar parameters:
57// n_batch - batch size / number of vectors
58// n_input, n_aux_input, n_output, n_cell - size of vectors.
59// activation - activation to use.
60// is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
61// use_layer_norm - if doing layer norm LSTM.
62inline void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights,
63 const float *aux_input, const float *aux_input_to_gate_weights,
64 const float *output_state,
65 const float *recurrent_to_gate_weights, const float *cell_state,
66 const float *cell_to_gate_weights,
67 const float *layer_norm_coefficients, const float *gate_bias,
68 const int n_batch, const int n_input, const int n_aux_input,
69 const int n_output, const int n_cell,
70 const FusedActivationFunctionType activation, float *gate,
71 const bool is_input_all_zeros, const bool is_aux_input_all_zeros)
72{
73 const bool use_peephole = (cell_to_gate_weights != nullptr);
74 const bool use_layer_norm = (layer_norm_coefficients != nullptr);
75
76 // Initialize scratch buffers with bias for regular lstm or initialize with
77 // zero for layer norm lstm.
78 if (use_layer_norm)
79 {
80 std::fill_n(gate, n_cell * n_batch, 0.0f);
81 }
82 else
83 {
84 VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
85 }
86 // For each batch and cell: compute input_weight * input.
87 // Skip if input is all zeros.
88 if (!is_input_all_zeros)
89 {
90 MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, n_cell, n_input, input, n_batch,
91 gate, /*result_stride=*/1);
92 }
93 // For each batch and cell: compute aux_input_weight * aux_input.
94 // Skip if auxiliary input is not available or all zeros.
95 if (!is_aux_input_all_zeros)
96 {
97 MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
98 n_batch, gate, /*result_stride=*/1);
99 }
100 // For each batch and cell: compute recurrent_weight * output_state.
101 MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, n_cell, n_output, output_state,
102 n_batch, gate, /*result_stride=*/1);
103 // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
104 if (use_peephole)
105 {
106 VectorBatchVectorCwiseProductAccumulate(cell_to_gate_weights, n_cell, cell_state, n_batch,
107 gate);
108 }
109 // Do layer normalization (if layer norm LSTM)
110 if (use_layer_norm)
111 {
112 MeanStddevNormalization(gate, gate, n_cell, n_batch);
113 VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, gate, n_batch, gate);
114 VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
115 }
116 // Apply activation
117 ApplyActivationToVector(gate, n_batch * n_cell, activation, gate);
118}
119
120// Updates the LSTM cell state, used by both float and hybrid LSTM versions.
121//
122// Implements the following formula:
123// cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
124//
125// With CIFG LSTM, input gate is replaced by (1-forget_gate).
126//
127// Parameters:
128// - n_batch, n_cell: sizes of vectors
129// - cell_state: input/output vector, size n_batch*n_cell
130// - input_gate: input vector, size n_batch*n_cell.
131// - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
132// - cell_gate: input vector, size n_batch*n_cell.
133// - use_cifg: use 1-forget_gate instead of input_gate.
134// - clip: if > 0, clip the resulting cell state to [-clip, +clip].
135void UpdateLstmCellFloat(int n_batch, int n_cell, float *cell_state, const float *input_gate,
136 float *forget_gate, const float *cell_gate, bool use_cifg, float clip)
137{
138 // Define variable for 4th argument to avoid warning
139 // Compiler warning: passing argument 4 to restrict-qualified parameter aliases with argument 2
140 const float *cwise_product_rhs = cell_state;
141 VectorVectorCwiseProduct(forget_gate, cwise_product_rhs, n_batch * n_cell, cell_state);
142
143 if (use_cifg)
144 {
145 // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
146 // scratch, as input_gate array is not allocated in this case. (Be careful
147 // not to write to the scratch before reading the forget gate data.)
148 float *scratch = forget_gate;
149 Sub1Vector(forget_gate, n_batch * n_cell, scratch);
150 VectorVectorCwiseProductAccumulate(cell_gate, scratch, n_batch * n_cell, cell_state);
151 }
152 else
153 {
154 VectorVectorCwiseProductAccumulate(cell_gate, input_gate, n_batch * n_cell, cell_state);
155 }
156 if (clip > 0.0f)
157 {
158 CwiseClipping(cell_state, n_batch * n_cell, clip);
159 }
160}
161
162// Calculates the output state tensor of an LSTM step.
163//
164// Implements the following formula:
165// output_no_projection = output_gate .* activate(cell_state)
166// (elementwise vector product)
167// If no projection is used:
168// output = output_state = output_no_projection
169// With projection:
170// output = output_state = clip(W*output_no_projection + bias)
171//
172// Output might not have a different 'stride' than n_batch, so we need to copy.
173//
174// Parameters:
175// - n_batch: batches: the number of distinct vectors in each array.
176// - n_cell, n_output: sizes of vectors.
177// - cell_state, output_gate: input vectors, size n_batch*n_cell.
178// - projection_weights, projection_weights_scale, projection_bias:
179// constant inputs, describing projection matrix and bias.
180// - proj_clip: if > 0, clip the output of the projection.
 181// - output_state: output vector, size n_batch*n_output. Must be contiguous.
182// - scratch: scratch area, size n_batch*n_cell.
183void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, const float *cell_state,
184 const float *output_gate, FusedActivationFunctionType activation,
185 const float *projection_weights, const float *projection_bias,
186 const float proj_clip, float *output_state, float *scratch)
187{
188 ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch);
189
190 // Define variable for 4th argument to avoid warning
191 // Compiler warning: passing argument 4 to restrict-qualified parameter aliases with argument 2
192 const float *cwise_product_rhs = scratch;
193 VectorVectorCwiseProduct(output_gate, cwise_product_rhs, n_batch * n_cell, scratch);
194
195 const bool use_projection = (projection_weights != nullptr);
196 const bool use_projection_bias = (projection_bias != nullptr);
197
198 if (use_projection)
199 {
200 if (use_projection_bias)
201 {
202 VectorBatchVectorAssign(projection_bias, n_output, n_batch, output_state);
203 }
204 else
205 {
206 std::fill_n(output_state, n_batch * n_output, 0.0f);
207 }
208 MatrixBatchVectorMultiplyAccumulate(projection_weights, n_output, n_cell, scratch, n_batch,
209 output_state, /*result_stride=*/1);
210 if (proj_clip > 0.0f)
211 {
212 CwiseClipping(output_state, n_batch * n_output, proj_clip);
213 }
214 }
215 else
216 {
217 std::copy_n(scratch, n_batch * n_output, output_state);
218 }
219}
220
221// Performs an LSTM batch inference step for input specified by input_ptr.
222// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
223// biases (*_bias_ptr), and buffers (*_scratch), along with additional
224// parameters:
225// - params: various LSTM params including activation, clipping, etc.,
226// - n_batch: size of batch,
227// - n_cell: number of cells (or units),
228// - n_input: the input size,
229// - n_aux_input: the auxiliary input size.
230// - n_output: the output size.
231// - output_batch_leading_dim: the leading dimension of the output buffer.
232//
233// Input of size 'n_batch * n_input':
234// input_ptr
235// Input of size 'n_batch * n_aux_input':
236// aux_input_ptr - optional (can be nullptr)
237//
238// LSTM weights:
239// Input weights of size 'n_cell * n_input':
240// input_to_input_weights - optional
241// input_to_forget_weights
242// input_to_cell_weights
243// input_to_output_weights
244// Auxiliary input weights of size 'n_cell * n_aux_input':
245// aux_input_to_input_weights - optional
246// aux_input_to_forget_weights - optional
247// aux_input_to_cell_weights - optional
248// aux_input_to_output_weights - optional
249// Recurrent weights of size 'n_cell * n_output':
250// recurrent_to_input_weights - optional
251// recurrent_to_forget_weights
252// recurrent_to_cell_weights
 253// recurrent_to_output_weights
254// Peephole weights of size 'n_cell', representing diagonal matrices.
255// cell_to_input_weights - optional
256// cell_to_cell_weights - optional
257// cell_to_output_weights - optional
258// Projection weights of size 'n_output * n_cell'
259// projection_weights_ptr - optional
260// Gate biases of size 'n_cell':
261// input_gate_bias_ptr - optional
262// forget_gate_bias_ptr
263// cell_gate_bias_ptr
264// output_gate_bias_ptr
265//
266// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
267// input_layer_norm_coefficients_ptr - optional
268// forget_layer_norm_coefficients_ptr - optional
269// cell_layer_norm_coefficients_ptr - optional
270// output_layer_norm_coefficients_ptr - optional
271//
272// The pointers to the cell and output state and the output are updated.
273//
274// The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
275// in batch_major order, and each step processes batch_size many inputs from
276// input_ptr, and updates batch_size many cell and output states.
277//
278// The output_batch_dim is output.shape[-1], i.e. the outermost dimension of the
279// output tensor, and in most cases will be equal to n_output. It is usually not
280// when we want to store the LSTM output into a slice of the output tensor, e.g.
281// for bidirectional LSTMs with merge_outputs. In this case, the batched
282// operations cannot be used since they assume that the batched outputs are
283// contiguous, and we manually loop over the batched outputs.
284// LINT.IfChange
285inline void LstmStepFloat(
286 const float *input_ptr, const float *input_to_input_weights_ptr,
287 const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr,
288 const float *input_to_output_weights_ptr, const float *aux_input_ptr,
289 const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr,
290 const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr,
291 const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr,
292 const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr,
293 const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr,
294 const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr,
295 const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr,
296 const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr,
297 const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr,
298 const float *output_gate_bias_ptr, const float *projection_weights_ptr,
299 const float *projection_bias_ptr, const LSTMParams *params, int n_batch, int n_cell, int n_input,
300 int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr,
301 float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3,
302 float *output_ptr)
303{
304 // Since we have already checked that weights are all there or none, we can
305 // check the existence of only one to the get the condition.
306 const bool use_cifg = (input_to_input_weights_ptr == nullptr);
307
308 // Make named scratch buffers.
309 float *input_gate_scratch = scratch0;
310 float *forget_gate_scratch = scratch1;
311 float *cell_gate_scratch = scratch2;
312 float *output_gate_scratch = scratch3;
313
314 // Check if inputs are all zeros so we can skip some computations.
315 const bool is_input_all_zeros = IsZeroVector(input_ptr, n_batch * n_input);
316 const bool is_aux_input_all_zeros =
317 (aux_input_ptr == nullptr || IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
318 if (!use_cifg)
319 {
320 // Calculate the input gate. (If not CIFG.)
321 CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
322 aux_input_to_input_weights_ptr, output_state_ptr,
323 recurrent_to_input_weights_ptr, cell_state_ptr,
324 cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
325 input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
326 /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
327 input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
328 }
329 // Calculate the forget gate.
330 CalculateLstmGateFloat(input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
331 aux_input_to_forget_weights_ptr, output_state_ptr,
332 recurrent_to_forget_weights_ptr, cell_state_ptr,
333 cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
334 forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
335 /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
336 forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
337 // Calculate the cell update gate.
339 input_ptr, input_to_cell_weights_ptr, aux_input_ptr, aux_input_to_cell_weights_ptr,
340 output_state_ptr, recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
341 /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, n_batch,
342 n_input, n_aux_input, n_output, n_cell, params->activation, cell_gate_scratch,
343 is_input_all_zeros, is_aux_input_all_zeros);
344 // Update the cell state.
345 UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
346 cell_gate_scratch, use_cifg, params->cell_clip);
347 // Calculate output gate.
348 CalculateLstmGateFloat(input_ptr, input_to_output_weights_ptr, aux_input_ptr,
349 aux_input_to_output_weights_ptr, output_state_ptr,
350 recurrent_to_output_weights_ptr, cell_state_ptr,
351 cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
352 output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
353 /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
354 output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
355 // Update the output state.
356 CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
357 params->activation, projection_weights_ptr, projection_bias_ptr,
358 params->proj_clip, output_state_ptr, scratch2);
359 // Copy output state to the output. Note that the output's rows may not be
360 // contiguous (output_batch_leading_dim != n_output).
361 for (int b = 0; b < n_batch; b++)
362 {
363 std::copy_n(output_state_ptr + b * n_output, n_output,
364 output_ptr + b * output_batch_leading_dim);
365 }
366}
367
368} // namespace cker
369} // namespace nnfw
370
371#endif // __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__
void VectorVectorCwiseProduct(const T *__restrict__ vector1, const T *__restrict__ vector2, int v_size, T *__restrict__ result)
Definition TensorUtils.h:52
void MeanStddevNormalization(const float *input_vector, float *output_vector, int v_size, int n_batch)
void VectorBatchVectorCwiseProduct(const T *vector, int v_size, const T *batch_vector, int n_batch, T *result)
Definition TensorUtils.h:76
void Sub1Vector(const float *vector, int v_size, float *result)
void MatrixBatchVectorMultiplyAccumulate(const int8_t *matrix, const int m_rows, const int m_cols, const int8_t *vector, const float *scaling_factors, int n_batch, float *result, int result_stride)
void ApplyActivationToVector(const float *vector, int v_size, FusedActivationFunctionType activation, float *result)
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, const float *cell_state, const float *output_gate, FusedActivationFunctionType activation, const float *projection_weights, const float *projection_bias, const float proj_clip, float *output_state, float *scratch)
Definition LSTM.h:183
void UpdateLstmCellFloat(int n_batch, int n_cell, float *cell_state, const float *input_gate, float *forget_gate, const float *cell_gate, bool use_cifg, float clip)
Definition LSTM.h:135
void VectorBatchVectorAssign(const float *vector, int v_size, int n_batch, float *batch_vector)
Definition TensorUtils.h:44
void CwiseClipping(float *vector, const int v_size, const float clipping_value)
Definition TensorUtils.h:34
void VectorVectorCwiseProductAccumulate(const T *__restrict__ vector1, const T *__restrict__ vector2, int v_size, T *__restrict__ result)
Definition TensorUtils.h:64
void VectorBatchVectorAdd(const float *vector, int v_size, int n_batch, float *batch_vector)
Definition TensorUtils.h:39
void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights, const float *aux_input, const float *aux_input_to_gate_weights, const float *output_state, const float *recurrent_to_gate_weights, const float *cell_state, const float *cell_to_gate_weights, const float *layer_norm_coefficients, const float *gate_bias, const int n_batch, const int n_input, const int n_aux_input, const int n_output, const int n_cell, const FusedActivationFunctionType activation, float *gate, const bool is_input_all_zeros, const bool is_aux_input_all_zeros)
Definition LSTM.h:62
void LstmStepFloat(const float *input_ptr, const float *input_to_input_weights_ptr, const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr, const float *input_to_output_weights_ptr, const float *aux_input_ptr, const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr, const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr, const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr, const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr, const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr, const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr, const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr, const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr, const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr, const float *output_gate_bias_ptr, const float *projection_weights_ptr, const float *projection_bias_ptr, const LSTMParams *params, int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr, float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3, float *output_ptr)
Definition LSTM.h:285
FusedActivationFunctionType
Definition Types.h:32
bool IsZeroVector(const float *vector, int v_size)
void VectorBatchVectorCwiseProductAccumulate(const T *vector, int v_size, const T *batch_vector, int n_batch, T *result)
Definition TensorUtils.h:92
Definition topk_v2.h:30
FusedActivationFunctionType activation
Definition Types.h:293