ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALGRU.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LUCI_INTERPRETER_PAL_GRU_H
18#define LUCI_INTERPRETER_PAL_GRU_H
19
20#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
21#include "PALreference_ops.h"
23{
24
// Element-wise logistic (sigmoid) over `flat_size` floats.
//
// tflite's Logistic kernel does not provide an in-place variant, so this local
// version is used instead. Each element is read before the corresponding
// output element is written, so `input_data` and `output_data` may point to the
// same buffer (calculateGRU calls it that way).
//
// Rationale for using an approximation in this reference kernel:
//  0. The approximation gives enough precision for float.
//  1. It works around an issue on an embedded chipset where exp() does not
//     behave as expected on overflow: exp(x) should return inf (IEEE 754
//     defines a representation for inf), not 1.701417.
//  2. It speeds up the calculation and matches the behavior of the optimized
//     kernels (see the definition of scalar_logistic_op<float>).
//
// Declared `inline` because it is defined in a header: without it, every
// translation unit including this header emits its own external definition,
// violating the One Definition Rule (multiple-definition link errors).
//
// @param flat_size    number of elements to process
// @param input_data   pointer to `flat_size` input values
// @param output_data  pointer to `flat_size` output values (may alias input)
inline void Logistic(const int flat_size, const float *input_data, float *output_data)
{
  // Inputs above this value are mapped to exactly 1.0f.
  constexpr float cutoff_upper = 16.619047164916992188f;
  // Inputs below this value use exp(x), since 1 / (1 + exp(-x)) ~= exp(x) there.
  constexpr float cutoff_lower = -9.f;

  for (int i = 0; i < flat_size; i++)
  {
    const float val = input_data[i];
    float result;
    if (val > cutoff_upper)
    {
      result = 1.0f;
    }
    else if (val < cutoff_lower)
    {
      result = std::exp(val);
    }
    else
    {
      result = 1.f / (1.f + std::exp(-val));
    }
    output_data[i] = result;
  }
}
58
59void calculateGRU(const float *input_data, const float *weight_input_data,
60 const float *weight_hidden_data, const float *bias_input_data,
61 const float *bias_hidden_data, float *output_data,
62 const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape,
63 const tflite::RuntimeShape &weight_input_shape,
64 const tflite::RuntimeShape &weight_hidden_shape, float *output_input_data,
65 float *output_hidden_data, const tflite::RuntimeShape &output_shape_fc)
66{
67 tflite::FullyConnectedParams op_params{};
68 // As FC nodes doesn't have any activations inside GRU, let' use just numeric limits
69 op_params.float_activation_min = std::numeric_limits<float>::lowest();
70 op_params.float_activation_max = std::numeric_limits<float>::max();
71
72 // FC Input
73 tflite::RuntimeShape bias_input_shape{weight_input_shape.Dims(0)};
74 tflite::reference_ops::FullyConnected(op_params, output_shape, output_data, weight_input_shape,
75 weight_input_data, bias_input_shape, bias_input_data,
76 output_shape_fc, output_input_data);
77
78 // FC Hidden
79 tflite::RuntimeShape bias_hidden_shape{weight_hidden_shape.Dims(0)};
80 // Note: input for this FC node will be saved without intermediate buffer
81 tflite::reference_ops::FullyConnected(op_params, input_shape, input_data, weight_hidden_shape,
82 weight_hidden_data, bias_hidden_shape, bias_hidden_data,
83 output_shape_fc, output_hidden_data);
84
85 int num_elements = output_shape_fc.Dims(1) / 3;
86
87 float *second_hidden_part = output_hidden_data + num_elements;
88 float *second_input_part = output_input_data + num_elements;
89
90 float *third_hidden_part = second_hidden_part + num_elements;
91 float *third_input_part = second_input_part + num_elements;
92
93 // Calculate Left part
94 for (int i = 0; i < num_elements; ++i)
95 {
96 output_input_data[i] += output_hidden_data[i];
97 }
98
99 Logistic(num_elements, output_input_data, output_input_data);
100
101 // Calculate most left mul
102 float *most_left_part_final = output_input_data;
103 float *first_part = output_input_data;
104 for (int i = 0; i < num_elements; ++i)
105 {
106 output_data[i] *= most_left_part_final[i];
107 first_part[i] = 1.0f - first_part[i];
108 }
109
110 // Calc second part
111 for (int i = 0; i < num_elements; ++i)
112 {
113 second_hidden_part[i] += second_input_part[i];
114 }
115
116 Logistic(num_elements, second_hidden_part, second_hidden_part);
117
118 for (int i = 0; i < num_elements; ++i)
119 {
120 second_hidden_part[i] *= third_input_part[i];
121 second_hidden_part[i] += third_hidden_part[i];
122 }
123
124 for (int i = 0; i < num_elements; ++i)
125 {
126 if (second_hidden_part[i] > 19)
127 {
128 second_hidden_part[i] = 1;
129 }
130 else if (second_hidden_part[i] < -19)
131 {
132 second_hidden_part[i] = -1;
133 }
134 else
135 {
136 second_hidden_part[i] = std::tanh(second_hidden_part[i]);
137 }
138 }
139
140 for (int i = 0; i < num_elements; ++i)
141 {
142 second_hidden_part[i] *= first_part[i];
143 output_data[i] += second_hidden_part[i];
144 }
145}
146
147void GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data,
148 const float *bias_input_data, const float *bias_hidden_data,
149 const float *hidden_state_data, float *output_data, float *output_input_data,
150 float *output_hidden_data, const tflite::RuntimeShape &input_shape,
151 const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape,
152 const tflite::RuntimeShape &weight_hidden_shape)
153{
154 const int32_t time = input_shape.Dims(0);
155
156 tflite::RuntimeShape output_shape_fc(2);
157 output_shape_fc.SetDim(0, 1);
158 output_shape_fc.SetDim(1, weight_hidden_shape.Dims(0));
159
160 std::memcpy(output_data, hidden_state_data, output_shape.FlatSize() * sizeof(float));
161
162 for (int i = 0; i < time; ++i)
163 {
164 calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data,
165 bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape,
166 weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc);
167 input_data += input_shape.Dims(2);
168 }
169}
170
171} // namespace luci_interpreter_pal
172
173#endif // LUCI_INTERPRETER_PAL_GRU_H
const luci_interpreter::RuntimeShape output_shape
void GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data, const float *bias_input_data, const float *bias_hidden_data, const float *hidden_state_data, float *output_data, float *output_input_data, float *output_hidden_data, const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape, const tflite::RuntimeShape &weight_hidden_shape)
Definition PALGRU.h:147
void calculateGRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data, const float *bias_input_data, const float *bias_hidden_data, float *output_data, const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape, const tflite::RuntimeShape &weight_hidden_shape, float *output_input_data, float *output_hidden_data, const tflite::RuntimeShape &output_shape_fc)
Definition PALGRU.h:59
void Logistic(const int flat_size, const float *input_data, float *output_data)
Definition PALGRU.h:26