ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALTanh.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_TANH_H
19#define LUCI_INTERPRETER_PAL_TANH_H
20
21#include "PALUtils.h"
22
24{
25
/**
 * @brief Element-wise hyperbolic tangent for float data.
 *
 * Reads flat_size values from input_data and stores std::tanh of each value at
 * the same index of output_data. Nothing is read or written when flat_size is
 * zero or negative.
 *
 * @param flat_size   number of elements to process
 * @param input_data  source buffer, read at indices [0, flat_size)
 * @param output_data destination buffer, written at indices [0, flat_size)
 */
inline void Tanh(const int flat_size, const float *input_data, float *output_data)
{
  for (int i = 0; i < flat_size; i++)
  {
    output_data[i] = std::tanh(input_data[i]);
  }
}
35
36inline void Tanh(int32_t input_multiplier, int32_t input_left_shift, const int flat_size,
37 const int16_t *ptr_input_data, int16_t *ptr_output_data)
38{
39 // We use the LUT for sigmoid and take into account, that
40 // tanh(x) = 2*sigmoid(2*x) - 1
41
42 // We scale by 3/4 to expand range [-8,8]->[-10.7,10.7].
43 // In case of general parameter scale, multiplier 3 is taken into account
44 // in TanhPrepare function and it is included in
45 // input_multiplier already.
46
47 if (input_multiplier == 0)
48 { // power of two case
49 input_multiplier = 3 << input_left_shift;
50 input_left_shift = 0;
51 }
52
53 int32_t round = (input_left_shift > 0) ? 1 << (input_left_shift - 1) : 0;
54
55 for (int i = 0; i < flat_size; ++i, ptr_input_data++, ptr_output_data++)
56 {
57 int32_t input_data = ((*ptr_input_data) * input_multiplier + round) >> input_left_shift;
58
59 uint32_t abs_input_data = abs(input_data);
60 uint32_t uh = abs_input_data >> 8;
61 int32_t result;
62
63 if (uh >= 255)
64 {
65 // Saturate to maximum.
66 result = 0xFFFF << 8;
67 }
68 else
69 {
70 uint32_t ua = sigmoid_table_uint16[uh];
71 uint32_t ub = sigmoid_table_uint16[uh + 1];
72
73 uint8_t ut = abs_input_data & 0xFF;
74
75 result = (ua << 8) + ut * (ub - ua);
76 }
77
78 result = (input_data >= 0) ? (result - (1 << (14 + 9)) + (1 << (9 - 2)))
79 : (-result + (1 << (14 + 9)) + (1 << (9 - 2)) - 1);
80
81 // Convert back to 16-bit.
82 result >>= (9 - 1);
83
84 *ptr_output_data = result;
85 }
86}
87
// Disabled int8 overload of Tanh: everything between `#if 0` and `#endif` is
// removed by the preprocessor and is never compiled.
//
// As written, each element has input_zero_point subtracted and is then handled
// in one of three ways:
//   - at or below -input_range_radius: the output is kMinInt8;
//   - at or above  input_range_radius: the output is kMaxInt8;
//   - otherwise: the value is rescaled with multiplyByQuantizedMultiplier,
//     passed to std::tanh, divided with roundingDivideByPOT by
//     2^(31 - kOutputScale), clamped to [kMinInt8, kMaxInt8] and stored.
//
// NOTE(review): kInputIntegerBits is declared but not used anywhere in this block.
// NOTE(review): std::tanh is applied to an int32_t and its result is stored in
// an int32_t, so any fractional part is discarded before the divide; this looks
// like an unfinished port of a fixed-point implementation - confirm before
// re-enabling.
// NOTE(review): the block uses std::numeric_limits, std::min and std::max;
// whether the needed headers are included cannot be seen from this file.
88#if 0
89inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
90 int32_t input_multiplier, int32_t input_shift,
91 const int flat_size, const int8_t* input_data, int8_t* output_data) {
92 // Integer bits must be in sync with Prepare() function.
93 static constexpr int32_t kInputIntegerBits = 4;
94 static constexpr int32_t kOutputScale = 7;
95 static constexpr int32_t kMinInt8 = std::numeric_limits<int8_t>::min();
96 static constexpr int32_t kMaxInt8 = std::numeric_limits<int8_t>::max();
97
98 for (int i = 0; i < flat_size; ++i) {
99 const int32_t input =
100 static_cast<int32_t>(input_data[i]) - input_zero_point;
101 if (input <= -input_range_radius) {
102 output_data[i] = kMinInt8;
103 } else if (input >= input_range_radius) {
104 output_data[i] = kMaxInt8;
105 } else {
106 const int32_t input_in_q4 =
107 multiplyByQuantizedMultiplier(input, input_multiplier, input_shift);
108 const int32_t output_in_q0 = std::tanh(input_in_q4);
109
110 int32_t output_in_q24 =
111 roundingDivideByPOT(output_in_q0, 31 - kOutputScale);
112 output_in_q24 = std::min(std::max(output_in_q24, kMinInt8), kMaxInt8);
113 output_data[i] = static_cast<int8_t>(output_in_q24);
114 }
115 }
116}
117#endif // 0
118
119} // namespace luci_interpreter_pal
120
121#endif // LUCI_INTERPRETER_PAL_TANH_H
void Tanh(const int flat_size, const float *input_data, float *output_data)
Definition PALTanh.h:26
int32_t roundingDivideByPOT(int32_t x, int32_t exponent)
Definition PALUtils.h:65
int32_t multiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
Definition PALUtils.h:77