#ifndef __NNFW_CKER_SOFTMAX_H__
#define __NNFW_CKER_SOFTMAX_H__

#include "cker/Shape.h"
#include "cker/Utils.h"
#include "cker/Types.h"
#include "cker/eigen/Utils.h"

#if __aarch64__ && __clang__
#define TFLITE_SOFTMAX_USE_UINT16_LUT
#endif

#include <Eigen/Core>
#include <fixedpoint/fixedpoint.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <type_traits>

namespace nnfw
{
namespace cker
{

namespace reference
{

// Note: this Softmax supports inputs of any rank; the softmax is taken over
// the last dimension.
inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const float *input_data,
                    const Shape &output_shape, float *output_data)
{
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
  for (int i = 0; i < outer_size; ++i)
  {
    // Find the max element value, used to keep exp() numerically stable:
    // exp(x[i]) / sum(exp(x[i])) == exp(x[i] + C) / sum(exp(x[i] + C))
    float max = std::numeric_limits<float>::lowest();
    for (int c = 0; c < depth; ++c)
    {
      max = std::max(max, input_data[i * depth + c]);
    }

    // Compute the sum of scaled exponentials.
    float sum = 0.f;
    for (int c = 0; c < depth; ++c)
    {
      sum += std::exp((input_data[i * depth + c] - max) * static_cast<float>(params.beta));
    }

    // Compute the normalized result.
    for (int c = 0; c < depth; ++c)
    {
      output_data[i * depth + c] =
        std::exp((input_data[i * depth + c] - max) * static_cast<float>(params.beta)) / sum;
    }
  }
}

} // namespace reference
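// Usage sketch (illustrative only; assumes Shape accepts an initializer list,
// as its TFLite counterpart does):
//
//   SoftmaxParams params;
//   params.beta = 1.0;
//   Shape shape({2, 3}); // softmax is taken over the last dimension (size 3)
//   float input[6] = {1.f, 2.f, 3.f, 1.f, 1.f, 1.f};
//   float output[6];
//   reference::Softmax(params, shape, input, shape, output);
//   // Each row of output now sums to 1; the second row is uniform (1/3 each).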
// Performs softmax along the input of size (input_size * batch_size).
inline void Softmax(const float *in, const int input_size, const int batch_size, const float beta,
                    float *out)
{
  assert(input_size > 0);

  // For each batch.
  for (int b = 0; b < batch_size; b++)
  {
    // Find the max coefficient for numerical stability.
    float max_coeff = in[0];
    for (int i = 1; i < input_size; i++)
    {
      if (in[i] > max_coeff)
        max_coeff = in[i];
    }

    // Compute the exponentials and accumulate their sum.
    float exp_sum = 0.f;
    for (int i = 0; i < input_size; i++)
    {
      out[i] = std::exp((in[i] - max_coeff) * beta);
      exp_sum += out[i];
    }

    // Divide by the sum of exponentials.
    float reciprocal_sum_exp = 1.f / exp_sum;
    for (int i = 0; i < input_size; i++)
    {
      out[i] *= reciprocal_sum_exp;
    }

    // Advance the pointers to the next batch.
    in += input_size;
    out += input_size;
  }
}
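// Usage sketch (illustrative only): this overload treats `in` as batch_size
// contiguous rows of input_size logits and writes the per-row softmax to `out`:
//
//   float in[4] = {0.f, 1.f, 2.f, 2.f}; // two batches of two logits
//   float out[4];
//   Softmax(in, /*input_size=*/2, /*batch_size=*/2, /*beta=*/1.f, out);
//   // out[0] + out[1] == 1.f, and out[2] == out[3] == 0.5f.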
inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const float *input_data,
                    const Shape &output_shape, float *output_data)
{
  // Validate that the input and output have the same flat size.
  MatchingFlatSize(input_shape, output_shape);

  const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
  auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
  // Compute the exponentials first, subtracting the per-column max for numerical stability.
  out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
  // exp is applied as a separate step so that Eigen can vectorize it.
  out_mat = out_mat.array().exp();
  // Normalize each column to obtain the activations.
  Eigen::Array<float, 1, Eigen::Dynamic> scale = out_mat.array().colwise().sum().inverse();
  out_mat.array().rowwise() *= scale;
}
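// The Eigen path computes the same quantity as the reference loop above:
//   softmax(x)_i = exp(beta * (x_i - max(x))) / sum_j exp(beta * (x_j - max(x)))
// Subtracting max(x) scales numerator and denominator by the same factor, so
// the result is unchanged, but it keeps exp() from overflowing for large inputs.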
template <typename T> inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point)
{
  const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
  return prob_rnd + zero_point;
}

#if !__aarch64__
// With ARM64, rounding is faster than add + truncation.
template <> inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled, int32_t)
{
  return static_cast<int32_t>(prob_rescaled + 0.5f);
}
#endif
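// Worked example (illustrative): with prob_rescaled == 127.6f and
// zero_point == 0, the generic path yields round(127.6f) + 0 == 128 and the
// uint8 fast path yields int32_t(127.6f + 0.5f) == 128. The two paths only
// differ for negative inputs, which cannot occur here because rescaled
// probabilities are non-negative.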
inline void PopulateSoftmaxLookupTable(float *table, float input_scale, float beta)
{
  const float scale = -input_scale * beta;
  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
  for (int32_t val = 0; val <= max_uint8; ++val)
  {
    table[max_uint8 - val] = expf(scale * val);
  }
}
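// The table is laid out so that table[255 - d] == exp(-input_scale * beta * d)
// for d in [0, 255]. A consumer that knows the row maximum max_val can then
// read exp(input_scale * beta * (x - max_val)) as (&table[255 - max_val])[x],
// which is exactly how the quantized Softmax below indexes it.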
template <typename In, typename Out>
inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const In *input_data,
                    const Shape &output_shape, Out *output_data)
{
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int excluding_last_dim = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int last_dim = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  const int32_t clamp_max = std::numeric_limits<Out>::max();
  const int32_t clamp_min = std::numeric_limits<Out>::min();
  for (int i = 0; i < excluding_last_dim; ++i)
  {
    // Find the max quantized value in this row.
    int32_t max_val = std::numeric_limits<In>::min();
    for (int j = 0; j < last_dim; ++j)
    {
      max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
    }

    float sum_exp = 0.0f;
    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
    const float *table_offset = &params.table[max_uint8 - max_val];
    // Accumulate the sum of exponentials via the lookup table.
    for (int j = 0; j < last_dim; ++j)
    {
      sum_exp += table_offset[input_data[j]];
    }

    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
    // Normalize and quantize the probabilities.
    for (int j = 0; j < last_dim; ++j)
    {
      const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
      const int32_t prob_quantized = QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
      output_data[j] = static_cast<Out>(std::max(std::min(clamp_max, prob_quantized), clamp_min));
    }
    input_data += last_dim;
    output_data += last_dim;
  }
}
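// Usage sketch (illustrative only; field names follow the SoftmaxParams usage
// in this header, and the 1/256 output scale with zero point 0 matches the
// usual uint8 softmax output quantization):
//
//   SoftmaxParams params;
//   params.zero_point = 0;
//   params.scale = 1.f / 256;
//   float table[256];
//   PopulateSoftmaxLookupTable(table, /*input_scale=*/0.1f, /*beta=*/1.f);
//   params.table = table;
//   Shape shape({1, 4});
//   uint8_t input[4] = {10, 20, 30, 255};
//   uint8_t output[4];
//   Softmax<uint8_t, uint8_t>(params, shape, input, shape, output);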
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT

// Looks up each byte of `indices` in `table` and returns the results as a vector.
inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4], uint8x16_t indices)
{
  // Look up in the 1st quarter of the table (indices 0x00..0x3f).
  uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
  // Look up in the 2nd quarter (indices 0x40..0x7f).
  uint8x16_t output2 = vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
  // Look up in the 3rd quarter (indices 0x80..0xbf).
  uint8x16_t output3 = vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
  // Look up in the 4th quarter (indices 0xc0..0xff).
  uint8x16_t output4 = vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));

  // Combine the results of the four lookups.
  return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
}
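// The OR-combine works because vqtbl4q_u8 returns 0 for out-of-range indices:
// after the XOR, each lane is a valid [0, 63] index for at most one of the
// four 64-byte table quarters (e.g. 0x41 ^ 0x40 == 0x01 selects table[1]
// only), so exactly one of output1..output4 is nonzero per lane.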
inline void PopulateSoftmaxUInt8LookupTable(uint8_t *uint8_table1, uint8_t *uint8_table2,
                                            float input_scale, float beta)
{
  const float scale = input_scale * beta;
  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
  const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();

  for (int32_t val = 0; val <= max_uint8; ++val)
  {
    float input_to_exp = scale * (val - max_uint8);
    int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
    temp = std::min(max_uint16, temp);
    // Split the 16-bit fixed-point value into two byte tables.
    uint8_t part1 = temp >> 8;
    uint8_t part2 = temp & 0xff;
    uint8_table1[val] = static_cast<uint8_t>(part1);
    uint8_table2[val] = static_cast<uint8_t>(part2);
  }
}
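// Worked example (illustrative): for val == 255 (i.e. x == max) the input to
// exp is 0, so temp == min(65535, int(1.0f * 65535 + 0.5)) == 65535, stored as
// part1 == 0xff and part2 == 0xff. The two byte tables together encode exp()
// in 0.16 fixed point: exp_value == (part1 << 8) + part2.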
inline int FindMaxValue(int size, const uint8_t *input_data, uint8_t offset)
{
  int32_t max_val = std::numeric_limits<uint8_t>::min();
  int j = 0;

  uint8x16_t max_val_dup = vdupq_n_u8(max_val);
  uint8x16_t offset_dup = vdupq_n_u8(offset);
  // Vectorized max over blocks of 16 bytes.
  for (; j <= size - 16; j += 16)
  {
    uint8x16_t input_value = vld1q_u8(input_data + j);
    input_value = veorq_u8(input_value, offset_dup);
    max_val_dup = vmaxq_u8(input_value, max_val_dup);
  }
  max_val = std::max(max_val, static_cast<int32_t>(vmaxvq_u8(max_val_dup)));

  // Scalar tail for the remaining elements.
  for (; j < size; ++j)
  {
    max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
  }
  return max_val;
}
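// Note on the XOR trick above (also used by SoftmaxInt8LUT below): XOR-ing an
// int8 value reinterpreted as uint8 with offset == 0x80 flips the sign bit,
// mapping the signed range [-128, 127] monotonically onto [0, 255]
// (e.g. int8 127 -> 0xff, int8 -128 -> 0x00), so unsigned max instructions
// compute the correct signed maximum.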
// value_to_store layout: [high_high, high_low, low_high, low_low].
inline void StoreValue(int32x4x4_t value_to_store, int8_t *output)
{
  const int16x8_t result_1 =
    vcombine_s16(vqmovn_s32(value_to_store.val[1]), vqmovn_s32(value_to_store.val[0]));
  const int16x8_t result_2 =
    vcombine_s16(vqmovn_s32(value_to_store.val[3]), vqmovn_s32(value_to_store.val[2]));
  const int8x16_t result = vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
  vst1q_s8(output, result);
}

// value_to_store layout: [high_high, high_low, low_high, low_low].
inline void StoreValue(int32x4x4_t value_to_store, uint8_t *output)
{
  const uint16x8_t result_1 =
    vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
                 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
  const uint16x8_t result_2 =
    vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
                 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
  const uint8x16_t result = vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
  vst1q_u8(output, result);
}
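// vqmovn_s32 / vqmovn_u32 narrow with saturation, so lanes already clamped to
// the output type's range pass through unchanged; the val[1]/val[0] and
// val[3]/val[2] combine order undoes the high/low splits performed when the
// 16 input lanes were widened, restoring the original byte order in memory.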
template <typename In, typename Out>
inline void SoftmaxInt8LUT(const SoftmaxParams &params, const Shape &input_shape,
                           const In *input_data, const Shape &output_shape, Out *output_data)
{
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int excluding_last_dim = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int last_dim = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  const int32_t clamp_max = std::numeric_limits<Out>::max();
  const int32_t clamp_min = std::numeric_limits<Out>::min();

  // The offset reinterprets the input "correctly": uint8 input is left
  // unchanged, while int8 input is XOR-ed with 0x80 so that, viewed as uint8,
  // e.g. int8 127 becomes 255.
  uint8_t offset = 0;
  if (std::is_same<In, int8_t>::value)
  {
    offset = 0x80;
  }

  const uint8_t *input_data_uint = reinterpret_cast<const uint8_t *>(input_data);

  // Load the lookup tables into registers (4 x 4 128-bit registers per table).
  uint8x16x4_t table1[4];
  table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
  table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
  table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
  table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);

  uint8x16x4_t table2[4];
  table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
  table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
  table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
  table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
  for (int i = 0; i < excluding_last_dim; ++i)
  {
    // Find the max quantized value in this row.
    int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);

    int32_t sum_exp = 0;
    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
    const uint8_t table_offset = max_uint8 - max_val;

    // Accumulate the sum of exponentials from the lookup tables.
    int sum_j = 0;
    uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
    uint8x16_t offset_dup = vdupq_n_u8(offset);
    uint32x4_t sum_4 = vdupq_n_u32(0);
    const int multiplier_shift = 8;
    for (; sum_j <= last_dim - 16; sum_j += 16)
    {
      uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
      input_value = veorq_u8(input_value, offset_dup);
      input_value = vaddq_u8(input_value, table_offset_dup);

      const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
      const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);

      // Recombine the two byte tables into 16-bit exp values: (part1 << 8) + part2.
      uint16x8_t exp_value1 = vshll_n_u8(vget_high_u8(output1), multiplier_shift);
      uint16x8_t exp_value2 = vshll_n_u8(vget_low_u8(output1), multiplier_shift);

      exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
      exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));

      sum_4 = vpadalq_u16(sum_4, exp_value1);
      sum_4 = vpadalq_u16(sum_4, exp_value2);
    }
    int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) + vgetq_lane_u32(sum_4, 2) +
               vgetq_lane_u32(sum_4, 3);
    sum_exp += temp;

    // Scalar tail for the remaining elements.
    for (; sum_j < last_dim; ++sum_j)
    {
      const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
      uint8_t part1 = params.uint8_table1[index];
      uint8_t part2 = params.uint8_table2[index];
      sum_exp += ((part1 << 8) + part2);
    }

    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);

    int32_t multiplier, shift;
    QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);

    // Normalize and quantize the probabilities.
    int j = 0;
    const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
    const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
    const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
    for (; j <= last_dim - 16; j += 16)
    {
      uint8x16_t input_value = vld1q_u8(input_data_uint + j);
      input_value = veorq_u8(input_value, offset_dup);
      input_value = vaddq_u8(input_value, table_offset_dup);

      const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
      const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);

      uint16x8_t exp_value1 = vshll_n_u8(vget_high_u8(output1), multiplier_shift);
      uint16x8_t exp_value2 = vshll_n_u8(vget_low_u8(output1), multiplier_shift);

      exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
      exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));

      int32x4x4_t output_value;
      output_value.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
      output_value.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
      output_value.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
      output_value.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));

      int32x4x4_t temp_val = MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);

      temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
      temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
      temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
      temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);

      temp_val.val[0] = vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
      temp_val.val[1] = vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
      temp_val.val[2] = vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
      temp_val.val[3] = vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);

      StoreValue(temp_val, output_data + j);
    }
    // Scalar tail for the remaining elements.
    for (; j < last_dim; ++j)
    {
      const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
      const uint8_t part1 = params.uint8_table1[index];
      const uint8_t part2 = params.uint8_table2[index];
      const int32_t exp_value = (part1 << 8) + part2;
      const int32_t output_value = MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);

      output_data[j] = static_cast<Out>(
        std::max(std::min(clamp_max, output_value + params.zero_point), clamp_min));
    }
    input_data_uint += last_dim;
    output_data += last_dim;
  }
}
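// Usage sketch (illustrative only; requires an aarch64 + clang build, since
// the vqtbl4q_u8-based lookups are only compiled under
// TFLITE_SOFTMAX_USE_UINT16_LUT; the -128 zero point and 1/256 scale match
// the usual int8 softmax output quantization):
//
//   SoftmaxParams params;
//   params.zero_point = -128;
//   params.scale = 1.f / 256;
//   uint8_t table1[256], table2[256];
//   PopulateSoftmaxUInt8LookupTable(table1, table2, /*input_scale=*/0.1f,
//                                   /*beta=*/1.f);
//   params.uint8_table1 = table1;
//   params.uint8_table2 = table2;
//   Shape shape({1, 16});
//   int8_t input[16] = {0}; // quantized logits
//   int8_t output[16];
//   SoftmaxInt8LUT<int8_t, int8_t>(params, shape, input, shape, output);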
#endif // TFLITE_SOFTMAX_USE_UINT16_LUT

} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_SOFTMAX_H__