ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALUnidirectionalSequenceLSTM.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
19#define LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
20
22
23#ifndef DIS_QUANT
24
26{
27// Evaluate the LSTM kernel with (potential) multi-step and multi-batch input.
//
// Full specialization of evalLSTM for <int8_t, int8_t, int16_t, int32_t>: the output
// state buffer is int8_t, and the cell state plus the four scratch buffers are int16_t.
// The same four template arguments are forwarded to lstm_internal::lstmStep
// (presumably activation, weight, cell and bias types — confirm against lstmStep).
//
// Loop bounds are read from the tensors referenced by lstm_struct:
//   - time-major input  (options->time_major() true):  [time_steps, batch_size, input_dimension]
//   - batch-major input (options->time_major() false): [batch_size, time_steps, input_dimension]
//   - state_dimension is taken from dim 1 of the output state tensor.
// lstmStep is then called once per time step (time-major) or once per
// (batch, time step) pair (batch-major). The step manager is advanced between calls
// with updateTime()/updateBatch()/resetTime().
28template <>
32 luci_interpreter::lstm::CellStateInfo *cell_state_info, int8_t *output_state_data,
33 int16_t *cell_state_data, int16_t *scratch0, int16_t *scratch1, int16_t *scratch2,
34 int16_t *scratch3, luci_interpreter::BaseRuntimeGraph *runtime_graph)
35{
37
  // Fill in the size descriptor. Which input axis holds batch and which holds time
  // depends on the time_major option.
38 size_info.time_major = lstm_struct->options->time_major();
39 size_info.batch_size = size_info.time_major
40 ? luci_interpreter::Tensor::dim(lstm_struct->input(), 1)
41 : luci_interpreter::Tensor::dim(lstm_struct->input(), 0);
42 size_info.time_steps = size_info.time_major
43 ? luci_interpreter::Tensor::dim(lstm_struct->input(), 0)
44 : luci_interpreter::Tensor::dim(lstm_struct->input(), 1);
45 size_info.input_dimension = luci_interpreter::Tensor::dim(lstm_struct->input(), 2);
46 size_info.state_dimension = luci_interpreter::Tensor::dim(lstm_struct->output_state(), 1);
47
  // Tracks the current (batch, time) position between lstmStep calls
  // (presumably as data offsets — see LstmStepManager).
48 lstm_internal::LstmStepManager step_info(size_info);
49
50 // Time is the first dimension, so no per-batch loop is needed: each lstmStep call
  // handles every batch at that time step (batch computation enabled).
51 if (size_info.time_major)
52 {
53 for (int t = 0; t < size_info.time_steps; t++)
54 {
55 lstm_internal::lstmStep<int8_t, int8_t, int16_t, int32_t>(
56 lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
57 scratch0, scratch1, scratch2, scratch3, runtime_graph);
58 // prepare for the next time step
59 step_info.updateTime();
60 }
61 }
62 else
63 {
64 // Batch is the first dimension, so the batches cannot be processed together in
  // one step. Run single-batch inference: every time step of batch 0, then batch 1, ...
65 for (int b = 0; b < size_info.batch_size; b++)
66 {
67 for (int t = 0; t < size_info.time_steps; t++)
68 {
69 lstm_internal::lstmStep<int8_t, int8_t, int16_t, int32_t>(
70 lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
71 scratch0, scratch1, scratch2, scratch3, runtime_graph);
72 // prepare for the next time step
73 step_info.updateTime();
74 }
75 // Prepare for the next batch and rewind the time position to step 0.
76 step_info.updateBatch();
77 step_info.resetTime();
78 }
79 }
80}
81
82} // namespace luci_interpreter_pal
83
84#endif // DIS_QUANT
85
86#endif // LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
void evalLSTM< int8_t, int8_t, int16_t, int32_t >(luci_interpreter::lstm::LSTMStruct *lstm_struct, luci_interpreter::lstm::LSTMParameters *lstm_params, luci_interpreter::lstm::CellStateInfo *cell_state_info, int8_t *output_state_data, int16_t *cell_state_data, int16_t *scratch0, int16_t *scratch1, int16_t *scratch2, int16_t *scratch3, luci_interpreter::BaseRuntimeGraph *runtime_graph)
const loco::Dimension & dim(uint32_t axis) const
Definition Tensor.h:44
const circle::UnidirectionalSequenceLSTMOptions * options