ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALSoftmax.h
Go to the documentation of this file.
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17
18#ifndef LUCI_INTERPRETER_PAL_SOFTMAX_H
19#define LUCI_INTERPRETER_PAL_SOFTMAX_H
20
21#include "PALSoftmaxCommon.h"
22
23#include <arm_nnfunctions.h>
24
namespace
{
// Number of int16_t entries in a lookup table built by LUTPopulateInt16:
// one entry per interpolation step plus one closing entry for the upper
// end of the input range.
constexpr int kInt16LUTArraySize = 513;

// Evaluates `transform` at `value`.
//
// `transform_params` is a type-erased slot for extra transform arguments;
// const void* is used instead of std::function to be compatible with
// TFLite Micro. This overload participates in overload resolution only when
// Func is exactly `FloatT (*)(FloatT)`, and it ignores `transform_params`.
template <typename FloatT, typename Func>
inline typename std::enable_if<std::is_same<Func, FloatT (*)(FloatT)>::value, FloatT>::type
LUTTransform(Func transform, const void * /*transform_params*/, FloatT value)
{
  static_assert(std::is_floating_point<FloatT>::value, "FloatT must be a floating-point type.");
  return transform(value);
}

// Fills `lut` with kInt16LUTArraySize int16_t samples of `transform`.
//
// The floating-point type is kept configurable for backward compatibility;
// float should be used for FloatT by default.
//
// The sampled input interval is the real-valued range that an int16_t covers
// under (input_scale, input_zero_point); it is split into
// kInt16LUTArraySize - 1 equal steps. Results are rescaled by
// 65536 / (output_max - output_min), where output_min/output_max are the
// real values an int16_t covers under (output_scale, output_zero_point),
// then rounded and clamped to the int16_t range.
//
// input_scale, input_zero_point    quantization of the table input
// output_scale, output_zero_point  quantization of the table output
// transform                        function being tabulated
// transform_params                 forwarded unchanged to LUTTransform
// lut                              destination; must have room for
//                                  kInt16LUTArraySize entries
template <typename FloatT, typename Func>
inline void LUTPopulateInt16(FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
                             int32_t output_zero_point, Func transform,
                             const void *transform_params, int16_t *lut)
{
  static_assert(std::is_floating_point<FloatT>::value, "FloatT must be a floating-point type.");
  using Int16Limits = std::numeric_limits<int16_t>;

  const FloatT input_min = input_scale * (Int16Limits::min() - input_zero_point);
  const FloatT input_max = input_scale * (Int16Limits::max() - input_zero_point);
  const FloatT output_min = output_scale * (Int16Limits::min() - output_zero_point);
  const FloatT output_max = output_scale * (Int16Limits::max() - output_zero_point);

  // The last table slot is reserved for input_max, so the number of steps is
  // one less than the table size.
  constexpr int nb_steps = kInt16LUTArraySize - 1;
  const FloatT step = (input_max - input_min) / nb_steps;
  const FloatT half_step = step / 2;
  const FloatT output_scaling_inv =
    static_cast<FloatT>(Int16Limits::max() - Int16Limits::min() + 1) / (output_max - output_min);
  const FloatT table_min = static_cast<FloatT>(Int16Limits::min());
  const FloatT table_max = static_cast<FloatT>(Int16Limits::max());

  // Saturates a rescaled sample to the representable int16_t range.
  const auto saturate = [table_min, table_max](FloatT scaled) {
    return static_cast<int16_t>(std::min<FloatT>(std::max<FloatT>(scaled, table_min), table_max));
  };

  for (int i = 0; i < nb_steps; i++)
  {
    // Transform evaluated at the start, middle and end of the i-th step.
    const FloatT val = LUTTransform<FloatT>(transform, transform_params, input_min + i * step);
    const FloatT val_midpoint =
      LUTTransform<FloatT>(transform, transform_params, input_min + i * step + half_step);
    const FloatT val_next =
      LUTTransform<FloatT>(transform, transform_params, input_min + (i + 1) * step);

    const FloatT sample_val = std::round(val * output_scaling_inv);
    // Value a linear interpolation between this entry and the next one would
    // produce at the middle of the step, versus the true midpoint value.
    const FloatT midpoint_interp_val =
      std::round((val_next * output_scaling_inv + sample_val) / 2);
    const FloatT midpoint_val = std::round(val_midpoint * output_scaling_inv);
    const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
    // The stored sample is shifted by half of the midpoint error.
    const FloatT bias = std::round(midpoint_err / 2);

    lut[i] = saturate(sample_val - bias);
  }

  // Closing entry: the transform at the upper end of the input range, with no
  // bias correction.
  lut[nb_steps] = saturate(
    std::round(LUTTransform<FloatT>(transform, transform_params, input_max) * output_scaling_inv));
}

// Builds an int16_t lookup table for a plain `float (*)(float)` transform.
//
// Enabled only for T == int16_t. Forwards to LUTPopulateInt16<float> with no
// transform parameters; `lut` must have room for kInt16LUTArraySize entries.
template <typename T>
inline typename std::enable_if<std::is_same<T, int16_t>::value, void>::type
LUTPopulate(float input_scale, int32_t input_zero_point, float output_scale,
            int32_t output_zero_point, float (*transform)(float), T *lut)
{
  LUTPopulateInt16<float>(input_scale, input_zero_point, output_scale, output_zero_point, transform,
                          nullptr, lut);
}

} // namespace
100
101namespace luci_interpreter_pal
102{
103
104inline void Softmax(const SoftmaxParams &params, const int8_t *input_data, int8_t *output_data)
105{
106 arm_softmax_s8(input_data, params.num_rows, params.row_size, params.input_multiplier,
107 params.input_left_shift, params.diff_min, output_data);
108}
109
110inline void Softmax(const SoftmaxParams &params, const int8_t *input_data, int16_t *output_data)
111{
112 arm_softmax_s8_s16(input_data, params.num_rows, params.row_size, params.input_multiplier,
113 params.input_left_shift, params.diff_min, output_data);
114}
115
116inline void Softmax(const SoftmaxParams &params, const int16_t *input_data, int16_t *output_data)
117{
118 cmsis_nn_softmax_lut_s16 softmax_params{};
119
120 auto raw_exp_lut = std::make_unique<int16_t[]>(kInt16LUTArraySize);
121 auto one_over_one_plus_x_lut = std::make_unique<int16_t[]>(kInt16LUTArraySize);
122
123 // exp LUT only used on negative values
124 // we consider exp(-10.0) is insignificant to accumulation
125 const int32_t range = std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min();
126
127 LUTPopulate<int16_t>(
128 10.0f / range, std::numeric_limits<int16_t>::max(), 2.0f / range, 0,
129 [](float value) { return std::exp(value); }, raw_exp_lut.get());
130
131 LUTPopulate<int16_t>(
132 1.0f / range, std::numeric_limits<int16_t>::min(), 2.0f / range, 0,
133 [](float value) { return 1.0f / (1.0f + value); }, one_over_one_plus_x_lut.get());
134
135 softmax_params.exp_lut = raw_exp_lut.get();
136 softmax_params.one_by_one_lut = one_over_one_plus_x_lut.get();
137
138 arm_softmax_s16(input_data, params.num_rows, params.row_size, params.input_multiplier,
139 params.input_left_shift, &softmax_params, output_data);
140}
141
142} // namespace luci_interpreter_pal
143
144#endif // LUCI_INTERPRETER_PAL_SOFTMAX_H