ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALUnidirectionalSequenceLSTM.h
Go to the documentation of this file.
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17
18#ifndef LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
19#define LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
20
#include "arm_nnfunctions.h"
#include "core/KernelParams.h"

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"

#include <cassert>
#include <cstdint>
#include <limits>
26
28{
29namespace lstm
30{
31
32inline cmsis_nn_lstm_params
33convert_lstm_params(const luci_interpreter::IntegerLSTMParams &params_in, bool time_major,
34 int32_t output_zeropoint, const int32_t *input_gate_bias,
35 const int32_t *forget_gate_bias, const int32_t *cell_gate_bias,
36 const int32_t *output_gate_bias, int16_t *input_layer_norm_coefficients,
37 int16_t *forget_layer_norm_coefficients, int16_t *cell_layer_norm_coefficients,
38 int16_t *output_layer_norm_coefficients)
39{
40 cmsis_nn_lstm_params params_out;
41
42 params_out.time_major = time_major;
43
44 // Multipliers and shifts for weights
45 params_out.input_to_input_scaling.multiplier = params_in.effective_input_to_input_scale_a;
46 params_out.input_to_input_scaling.shift = params_in.effective_input_to_input_scale_b;
47 params_out.recurrent_to_input_scaling.multiplier = params_in.effective_recurrent_to_input_scale_a;
48 params_out.recurrent_to_input_scaling.shift = params_in.effective_recurrent_to_input_scale_b;
49 params_out.cell_to_input_scaling.multiplier = params_in.effective_cell_to_input_scale_a;
50 params_out.cell_to_input_scaling.shift = params_in.effective_cell_to_input_scale_b;
51 params_out.input_to_forget_scaling.multiplier = params_in.effective_input_to_forget_scale_a;
52 params_out.input_to_forget_scaling.shift = params_in.effective_input_to_forget_scale_b;
53 params_out.recurrent_to_forget_scaling.multiplier =
54 params_in.effective_recurrent_to_forget_scale_a;
55 params_out.recurrent_to_forget_scaling.shift = params_in.effective_recurrent_to_forget_scale_b;
56 params_out.cell_to_forget_scaling.multiplier = params_in.effective_cell_to_forget_scale_a;
57 params_out.cell_to_forget_scaling.shift = params_in.effective_cell_to_forget_scale_b;
58 params_out.input_to_cell_scaling.multiplier = params_in.effective_input_to_cell_scale_a;
59 params_out.input_to_cell_scaling.shift = params_in.effective_input_to_cell_scale_b;
60 params_out.recurrent_to_cell_scaling.multiplier = params_in.effective_recurrent_to_cell_scale_a;
61 params_out.recurrent_to_cell_scaling.shift = params_in.effective_recurrent_to_cell_scale_b;
62 params_out.input_to_output_scaling.multiplier = params_in.effective_input_to_output_scale_a;
63 params_out.input_to_output_scaling.shift = params_in.effective_input_to_output_scale_b;
64
65 params_out.recurrent_to_output_scaling.multiplier =
66 params_in.effective_recurrent_to_output_scale_a;
67 params_out.recurrent_to_output_scaling.shift = params_in.effective_recurrent_to_output_scale_b;
68 params_out.cell_to_output_scaling.multiplier = params_in.effective_cell_to_output_scale_a;
69 params_out.cell_to_output_scaling.shift = params_in.effective_cell_to_output_scale_b;
70 params_out.projection_scaling.multiplier = params_in.effective_proj_scale_a;
71 params_out.projection_scaling.shift = params_in.effective_proj_scale_b;
72
73 params_out.layer_norm_input_scaling.multiplier = params_in.layer_norm_input_scale_a;
74 params_out.layer_norm_input_scaling.shift = params_in.layer_norm_input_scale_b;
75 params_out.layer_norm_forget_scaling.multiplier = params_in.layer_norm_forget_scale_a;
76 params_out.layer_norm_forget_scaling.shift = params_in.layer_norm_forget_scale_b;
77 params_out.layer_norm_cell_scaling.multiplier = params_in.layer_norm_cell_scale_a;
78 params_out.layer_norm_cell_scaling.shift = params_in.layer_norm_cell_scale_b;
79 params_out.layer_norm_output_scaling.multiplier = params_in.layer_norm_output_scale_a;
80 params_out.layer_norm_output_scaling.shift = params_in.layer_norm_output_scale_b;
81
82 params_out.clip.cell = params_in.quantized_cell_clip;
83 params_out.clip.projection = params_in.quantized_proj_clip;
84
85 params_out.cell_state_shift = params_in.cell_scale;
86
87 params_out.hidden_offset = params_in.hidden_zp;
88 params_out.output_state_offset = output_zeropoint;
89
90 params_out.guard.input_variance = params_in.input_variance_guard;
91 params_out.guard.forget_variance = params_in.forget_variance_guard;
92 params_out.guard.cell_variance = params_in.cell_variance_guard;
93 params_out.guard.output_variance = params_in.output_variance_guard;
94
95 params_out.i2f_effective_bias = params_in.input_to_forget_effective_bias.data();
96 params_out.r2f_effective_bias = params_in.recurrent_to_forget_effective_bias.data();
97 params_out.i2c_effective_bias = params_in.input_to_cell_effective_bias.data();
98 params_out.r2c_effective_bias = params_in.recurrent_to_cell_effective_bias.data();
99 params_out.i2o_effective_bias = params_in.input_to_output_effective_bias.data();
100 params_out.r2o_effective_bias = params_in.recurrent_to_output_effective_bias.data();
101 params_out.i2i_effective_bias = params_in.input_to_input_effective_bias.data();
102 params_out.r2i_effective_bias = params_in.recurrent_to_input_effective_bias.data();
103 params_out.projection_effective_bias = params_in.projection_effective_bias.data();
104
105 params_out.hidden_scaling.multiplier = params_in.effective_hidden_scale_a;
106 params_out.hidden_scaling.shift = params_in.effective_hidden_scale_b;
107
108 params_out.input_gate_bias = input_gate_bias;
109 params_out.forget_gate_bias = forget_gate_bias;
110 params_out.cell_gate_bias = cell_gate_bias;
111 params_out.output_gate_bias = output_gate_bias;
112
113 params_out.layer_norm.input_weight = input_layer_norm_coefficients;
114 params_out.layer_norm.forget_weight = forget_layer_norm_coefficients;
115 params_out.layer_norm.cell_weight = cell_layer_norm_coefficients;
116 params_out.layer_norm.output_weight = output_layer_norm_coefficients;
117
118 params_out.activation.min = std::numeric_limits<int16_t>::min();
119 params_out.activation.max = std::numeric_limits<int16_t>::max();
120
121 return params_out;
122}
123
124} // namespace lstm
125
127 const luci_interpreter::Tensor *input, const luci_interpreter::Tensor *input_to_input_weights,
128 const luci_interpreter::Tensor *input_to_forget_weights,
129 const luci_interpreter::Tensor *input_to_cell_weights,
130 const luci_interpreter::Tensor *input_to_output_weights,
131 const luci_interpreter::Tensor *recurrent_to_input_weights,
132 const luci_interpreter::Tensor *recurrent_to_forget_weights,
133 const luci_interpreter::Tensor *recurrent_to_cell_weights,
134 const luci_interpreter::Tensor *recurrent_to_output_weights,
135 const luci_interpreter::Tensor *cell_to_input_weights,
136 const luci_interpreter::Tensor *cell_to_forget_weights,
137 const luci_interpreter::Tensor *cell_to_output_weights,
138 const luci_interpreter::Tensor *input_layer_norm_coefficients,
139 const luci_interpreter::Tensor *forget_layer_norm_coefficients,
140 const luci_interpreter::Tensor *cell_layer_norm_coefficients,
141 const luci_interpreter::Tensor *output_layer_norm_coefficients,
142 const luci_interpreter::Tensor *input_gate_bias, const luci_interpreter::Tensor *forget_gate_bias,
143 const luci_interpreter::Tensor *cell_gate_bias, const luci_interpreter::Tensor *output_gate_bias,
144 const luci_interpreter::Tensor *projection_weights,
145 const luci_interpreter::Tensor *projection_bias,
146 const luci_interpreter::UnidirectionalSequenceLSTMParams &params, bool forward_sequence,
147 bool time_major, const luci_interpreter::IntegerLSTMParams &integer_lstm_param,
148 int32_t output_state_zp, luci_interpreter::Tensor *output_state,
149 luci_interpreter::Tensor *cell_state, luci_interpreter::Tensor *output, int16_t *scratch0,
150 int16_t *scratch1, int16_t *scratch2, int16_t *scratch3, int8_t *scratch4, int32_t *scratch5)
151{
152 // CMSIS-NN does not support these configurations currently.
153 // Please use MCU kernels instead
154 const bool use_layer_norm = (forget_layer_norm_coefficients != nullptr);
155 const bool use_peephole = (cell_to_output_weights != nullptr);
156 const bool use_projection = (projection_weights != nullptr);
157 const bool use_cifg = (input_to_input_weights == nullptr);
158 const bool unsupported_config = use_layer_norm || use_peephole || use_projection || use_cifg;
159
160 if (unsupported_config)
161 {
162 assert(false && "CMSIS-NN does not support these configurations currently");
163 return;
164 }
165
166 const auto input_shape = input->shape();
167 LUCI_INTERPRETER_CHECK(input_shape.num_dims() >= 2 && input_shape.num_dims() <= 3);
168
169 cmsis_nn_lstm_context scratch_buffers;
170 scratch_buffers.input_gate = scratch0;
171 scratch_buffers.forget_gate = scratch1;
172 scratch_buffers.cell_gate = scratch2;
173 scratch_buffers.output_gate = scratch3;
174 scratch_buffers.scratch = scratch4;
175
176 cmsis_nn_lstm_params cmsis_lstm_params = lstm::convert_lstm_params(
177 integer_lstm_param, time_major, output_state_zp,
178 luci_interpreter::kernels::getTensorData<int32_t>(input_gate_bias),
179 luci_interpreter::kernels::getTensorData<int32_t>(forget_gate_bias),
180 luci_interpreter::kernels::getTensorData<int32_t>(cell_gate_bias),
181 luci_interpreter::kernels::getTensorData<int32_t>(output_gate_bias),
182 const_cast<int16_t *>(
183 luci_interpreter::kernels::getTensorData<int16_t>(input_layer_norm_coefficients)),
184 const_cast<int16_t *>(
185 luci_interpreter::kernels::getTensorData<int16_t>(forget_layer_norm_coefficients)),
186 const_cast<int16_t *>(
187 luci_interpreter::kernels::getTensorData<int16_t>(cell_layer_norm_coefficients)),
188 const_cast<int16_t *>(
189 luci_interpreter::kernels::getTensorData<int16_t>(output_layer_norm_coefficients)));
190
191 const int n_input = input_shape.dim(input_shape.num_dims() - 1);
192 int max_time, n_batch;
193 if (input_shape.num_dims() == 2)
194 {
195 max_time = 1;
196 n_batch = input_shape.dim(0);
197 }
198 else
199 {
200 max_time = (time_major) ? input_shape.dim(0) : input_shape.dim(1);
201 n_batch = (time_major) ? input_shape.dim(1) : input_shape.dim(0);
202 }
203
204 // n_cell and n_output will be the same size when there is no projection.
205 const int n_cell = input_to_output_weights->shape().dim(0);
206 const int n_output = recurrent_to_output_weights->shape().dim(1);
207
208 cmsis_nn_lstm_dims lstm_dims;
209 lstm_dims.num_inputs = n_input;
210 lstm_dims.num_outputs = n_output;
211 lstm_dims.num_batches = n_batch;
212 lstm_dims.max_time = max_time;
213
214 arm_lstm_unidirectional_s16_s8(
215 &scratch_buffers, const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(input)),
216 &lstm_dims,
217 const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(input_to_input_weights)),
218 const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(input_to_forget_weights)),
219 const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(input_to_cell_weights)),
220 const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(input_to_output_weights)),
221 const_cast<int8_t *>(
222 luci_interpreter::kernels::getTensorData<int8_t>(recurrent_to_input_weights)),
223 const_cast<int8_t *>(
224 luci_interpreter::kernels::getTensorData<int8_t>(recurrent_to_forget_weights)),
225 const_cast<int8_t *>(
226 luci_interpreter::kernels::getTensorData<int8_t>(recurrent_to_cell_weights)),
227 const_cast<int8_t *>(
228 luci_interpreter::kernels::getTensorData<int8_t>(recurrent_to_output_weights)),
229 const_cast<int16_t *>(luci_interpreter::kernels::getTensorData<int16_t>(cell_to_input_weights)),
230 const_cast<int16_t *>(
231 luci_interpreter::kernels::getTensorData<int16_t>(cell_to_forget_weights)),
232 const_cast<int16_t *>(
233 luci_interpreter::kernels::getTensorData<int16_t>(cell_to_output_weights)),
234 const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(projection_weights)),
235 &cmsis_lstm_params,
236 const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(output_state)),
237 const_cast<int16_t *>(luci_interpreter::kernels::getTensorData<int16_t>(cell_state)),
238 const_cast<int8_t *>(luci_interpreter::kernels::getTensorData<int8_t>(output)));
239}
240
241} // namespace luci_interpreter_pal
242
243#endif // LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
int32_t dim(int i) const
Definition Tensor.h:41
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
cmsis_nn_lstm_params convert_lstm_params(const luci_interpreter::IntegerLSTMParams &params_in, bool time_major, int32_t output_zeropoint, const int32_t *input_gate_bias, const int32_t *forget_gate_bias, const int32_t *cell_gate_bias, const int32_t *output_gate_bias, int16_t *input_layer_norm_coefficients, int16_t *forget_layer_norm_coefficients, int16_t *cell_layer_norm_coefficients, int16_t *output_layer_norm_coefficients)
void eval_integer_8x8_16_lstm(const luci_interpreter::Tensor *input, const luci_interpreter::Tensor *input_to_input_weights, const luci_interpreter::Tensor *input_to_forget_weights, const luci_interpreter::Tensor *input_to_cell_weights, const luci_interpreter::Tensor *input_to_output_weights, const luci_interpreter::Tensor *recurrent_to_input_weights, const luci_interpreter::Tensor *recurrent_to_forget_weights, const luci_interpreter::Tensor *recurrent_to_cell_weights, const luci_interpreter::Tensor *recurrent_to_output_weights, const luci_interpreter::Tensor *cell_to_input_weights, const luci_interpreter::Tensor *cell_to_forget_weights, const luci_interpreter::Tensor *cell_to_output_weights, const luci_interpreter::Tensor *input_layer_norm_coefficients, const luci_interpreter::Tensor *forget_layer_norm_coefficients, const luci_interpreter::Tensor *cell_layer_norm_coefficients, const luci_interpreter::Tensor *output_layer_norm_coefficients, const luci_interpreter::Tensor *input_gate_bias, const luci_interpreter::Tensor *forget_gate_bias, const luci_interpreter::Tensor *cell_gate_bias, const luci_interpreter::Tensor *output_gate_bias, const luci_interpreter::Tensor *projection_weights, const luci_interpreter::Tensor *projection_bias, const luci_interpreter::UnidirectionalSequenceLSTMParams &params, bool forward_sequence, bool time_major, const luci_interpreter::IntegerLSTMParams &integer_lstm_param, int32_t output_state_zp, luci_interpreter::Tensor *output_state, luci_interpreter::Tensor *cell_state, luci_interpreter::Tensor *output, int16_t *scratch0, int16_t *scratch1, int16_t *scratch2, int16_t *scratch3, int8_t *scratch4, int32_t *scratch5)