18#ifndef LUCI_INTERPRETER_PAL_SOFTMAX_H
19#define LUCI_INTERPRETER_PAL_SOFTMAX_H
21#include "PALSoftmaxCommon.h"
23#include <arm_nnfunctions.h>
// Number of int16 entries in each lookup table built by LUTPopulateInt16:
// the input range is sampled at 512 uniform steps and both end points are
// stored, so a table holds 512 + 1 values.
constexpr int kInt16LUTArraySize = 513;
/**
 * @brief Applies a plain function-pointer transform to a single value.
 *
 * This overload is enabled only when Func is exactly `FloatT (*)(FloatT)`.
 * The second parameter is unnamed and is not used by this overload.
 *
 * @tparam FloatT Floating-point type of the argument and of the result.
 * @tparam Func   Callable type; must be `FloatT (*)(FloatT)` for this overload.
 * @param transform Function to evaluate.
 * @param value     Argument passed to @p transform.
 * @return `transform(value)`.
 */
template <typename FloatT, typename Func>
inline typename std::enable_if<std::is_same<Func, FloatT (*)(FloatT)>::value, FloatT>::type
LUTTransform(Func transform, const void * /*transform_params*/, FloatT value)
{
  static_assert(std::is_floating_point<FloatT>::value, "FloatT must be a floating-point type.");
  return transform(value);
}
42template <
typename FloatT,
typename Func>
43inline void LUTPopulateInt16(FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
44 int32_t output_zero_point, Func transform,
45 const void *transform_params, int16_t *lut)
47 static_assert(std::is_floating_point<FloatT>::value,
"FloatT must be a floating-point type.");
48 const FloatT input_min = input_scale * (std::numeric_limits<int16_t>::min() - input_zero_point);
49 const FloatT input_max = input_scale * (std::numeric_limits<int16_t>::max() - input_zero_point);
50 const FloatT output_min =
51 output_scale * (std::numeric_limits<int16_t>::min() - output_zero_point);
52 const FloatT output_max =
53 output_scale * (std::numeric_limits<int16_t>::max() - output_zero_point);
55 const int nb_steps = 512;
56 const FloatT step = (input_max - input_min) / nb_steps;
57 const FloatT half_step = step / 2;
58 const FloatT output_scaling_inv =
static_cast<FloatT
>(std::numeric_limits<int16_t>::max() -
59 std::numeric_limits<int16_t>::min() + 1) /
60 (output_max - output_min);
61 const FloatT table_min =
static_cast<FloatT
>(std::numeric_limits<int16_t>::min());
62 const FloatT table_max =
static_cast<FloatT
>(std::numeric_limits<int16_t>::max());
64 for (
int i = 0; i < nb_steps; i++)
66 const FloatT val = LUTTransform<FloatT>(transform, transform_params, input_min + i * step);
67 const FloatT val_midpoint =
68 LUTTransform<FloatT>(transform, transform_params, input_min + i * step + half_step);
69 const FloatT val_next =
70 LUTTransform<FloatT>(transform, transform_params, input_min + (i + 1) * step);
72 const FloatT sample_val = std::round(val * output_scaling_inv);
73 const FloatT midpoint_interp_val =
74 std::round((val_next * output_scaling_inv + std::round(val * output_scaling_inv)) / 2);
75 const FloatT midpoint_val = std::round(val_midpoint * output_scaling_inv);
76 const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
77 const FloatT
bias = std::round(midpoint_err / 2);
79 lut[i] =
static_cast<int16_t
>(
80 std::min<FloatT>(std::max<FloatT>(sample_val - bias, table_min), table_max));
83 lut[nb_steps] =
static_cast<int16_t
>(std::min<FloatT>(
85 std::round(LUTTransform<FloatT>(transform, transform_params, input_max) * output_scaling_inv),
91inline typename std::enable_if<std::is_same<T, int16_t>::value,
void>::type
92LUTPopulate(
float input_scale, int32_t input_zero_point,
float output_scale,
93 int32_t output_zero_point,
float (*transform)(
float), T *lut)
95 LUTPopulateInt16<float>(input_scale, input_zero_point, output_scale, output_zero_point, transform,
104inline void Softmax(
const SoftmaxParams ¶ms,
const int8_t *input_data, int8_t *output_data)
110inline void Softmax(
const SoftmaxParams ¶ms,
const int8_t *input_data, int16_t *output_data)
116inline void Softmax(
const SoftmaxParams ¶ms,
const int16_t *input_data, int16_t *output_data)
118 cmsis_nn_softmax_lut_s16 softmax_params{};
120 auto raw_exp_lut = std::make_unique<int16_t[]>(kInt16LUTArraySize);
121 auto one_over_one_plus_x_lut = std::make_unique<int16_t[]>(kInt16LUTArraySize);
125 const int32_t range = std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min();
127 LUTPopulate<int16_t>(
128 10.0f / range, std::numeric_limits<int16_t>::max(), 2.0f / range, 0,
129 [](
float value) {
return std::exp(value); }, raw_exp_lut.get());
131 LUTPopulate<int16_t>(
132 1.0f / range, std::numeric_limits<int16_t>::min(), 2.0f / range, 0,
133 [](
float value) {
return 1.0f / (1.0f + value); }, one_over_one_plus_x_lut.get());
135 softmax_params.exp_lut = raw_exp_lut.get();
136 softmax_params.one_by_one_lut = one_over_one_plus_x_lut.get();