ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SoftMax.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_SOFTMAX_H__
19#define __NNFW_CKER_SOFTMAX_H__
20
21#include "cker/Shape.h"
22#include "cker/Utils.h"
23#include "cker/Types.h"
24#include "cker/eigen/Utils.h"
25
26#if __aarch64__ && __clang__
27#define TFLITE_SOFTMAX_USE_UINT16_LUT
28#endif
29
30#include <Eigen/Core>
31#include <fixedpoint/fixedpoint.h>
32#include <cmath>
33
34namespace nnfw
35{
36namespace cker
37{
38
39namespace reference
40{
41
42// Note. This Softmax function supports all of dimensions
43inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const float *input_data,
44 const Shape &output_shape, float *output_data)
45{
46 const int trailing_dim = input_shape.DimensionsCount() - 1;
47 const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
48 const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
49
50 for (int i = 0; i < outer_size; ++i)
51 {
52 // Find max element value which we'll use to ensure numerical stability
53 // taking advantage of the following equality:
54 // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
55 float max = std::numeric_limits<float>::lowest();
56 for (int c = 0; c < depth; ++c)
57 {
58 max = std::max(max, input_data[i * depth + c]);
59 }
60
61 // Compute sum.
62 float sum = 0.f;
63 for (int c = 0; c < depth; ++c)
64 {
65 sum += std::exp((input_data[i * depth + c] - max) * static_cast<float>(params.beta));
66 }
67
68 // Compute result.
69 for (int c = 0; c < depth; ++c)
70 {
71 output_data[i * depth + c] =
72 std::exp((input_data[i * depth + c] - max) * static_cast<float>(params.beta)) / sum;
73 }
74 }
75}
76} // namespace reference
77
// Performs softmax along the input of size (input_size * batch_size).
// `in` and `out` each hold batch_size contiguous rows of input_size floats;
// every row is normalized independently so it sums to one.
inline void Softmax(const float *in, const int input_size, const int batch_size, const float beta,
                    float *out)
{
  assert(input_size > 0);

  // Process one row per iteration; `in`/`out` advance a full row at the end.
  for (int batch = 0; batch < batch_size; ++batch)
  {
    // Row maximum, subtracted before exponentiation for numerical stability
    // (softmax is invariant under a constant shift of its inputs).
    float row_max = in[0];
    for (int idx = 1; idx < input_size; ++idx)
    {
      row_max = std::max(row_max, in[idx]);
    }

    // Stage exp((x - max) * beta) in the output row while accumulating the
    // normalizer.
    float sum_of_exps = 0.0;
    for (int idx = 0; idx < input_size; ++idx)
    {
      out[idx] = std::exp((in[idx] - row_max) * beta);
      sum_of_exps += out[idx];
    }

    // Scale the row by the reciprocal of the sum so it totals one.
    const float inv_sum = 1.f / sum_of_exps;
    for (int idx = 0; idx < input_size; ++idx)
    {
      out[idx] *= inv_sum;
    }

    in += input_size;
    out += input_size;
  }
}
115
// Softmax over the last dimension, vectorized with Eigen.
// MapAsMatrixWithLastDimAsRows views the data as a matrix whose rows are the
// last dimension, so each COLUMN of in_mat/out_mat is one softmax group.
inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const float *input_data,
                    const Shape &output_shape, float *output_data)
{
  // Validate that input and output have the same flat size (asserts inside).
  MatchingFlatSize(input_shape, output_shape);

  const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
  auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
  // Compute the exponential first, removing the max coefficient for numerical
  // stability.
  out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
  // We are separating out the exp function so that exp can be vectorized.
  out_mat = out_mat.array().exp();
  // Normalize to get the activations: multiply each column by the reciprocal
  // of its sum.
  Eigen::Array<float, 1, Eigen::Dynamic> scale = out_mat.array().colwise().sum().inverse();
  out_mat.array().rowwise() *= scale;
}
133
// Converts a rescaled softmax probability to a quantized output value:
// round to the nearest integer (half away from zero), then shift by the
// output zero point. The template parameter selects per-type specializations;
// it is not used in this generic version.
template <typename T> inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point)
{
  return static_cast<int32_t>(std::round(prob_rescaled)) + zero_point;
}
139
#if !__aarch64__
// With ARM64, rounding is faster than add + truncation.
// Non-ARM64 uint8 specialization: quantize by adding 0.5 and truncating,
// which matches round-half-up for the non-negative probabilities produced by
// the softmax LUT paths. The zero_point argument is intentionally unused —
// NOTE(review): presumably uint8 softmax output always uses zero_point 0;
// confirm against callers before relying on this.
template <> inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled, int32_t)
{
  return static_cast<int32_t>(prob_rescaled + 0.5f);
}
#endif
147
// Precomputes exp(-input_scale * beta * d) for every possible distance
// d = 0..255 between a quantized input value and its row maximum.
// Layout: table[255 - d] == exp(-input_scale * beta * d), so table[255] == 1
// and entries decay toward table[0]. The quantized Softmax below indexes the
// table relative to the row max.
inline void PopulateSoftmaxLookupTable(float *table, float input_scale, float beta)
{
  const float scale = -input_scale * beta;
  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
  for (int32_t idx = 0; idx <= max_uint8; ++idx)
  {
    // Entry idx corresponds to distance (max_uint8 - idx) from the maximum.
    table[idx] = expf(scale * (max_uint8 - idx));
  }
}
157
// Quantized softmax over the last dimension using the float lookup table
// prepared by PopulateSoftmaxLookupTable (params.table).
// In/Out are quantized element types; scaling uses params.scale and the
// result is shifted by params.zero_point and clamped to Out's range.
template <typename In, typename Out>
inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const In *input_data,
                    const Shape &output_shape, Out *output_data)
{
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int excluding_last_dim = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int last_dim = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  const int32_t clamp_max = std::numeric_limits<Out>::max();
  const int32_t clamp_min = std::numeric_limits<Out>::min();
  for (int i = 0; i < excluding_last_dim; ++i)
  {
    int32_t max_val = std::numeric_limits<In>::min();
    // Find max quantized value.
    for (int j = 0; j < last_dim; ++j)
    {
      max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
    }

    float sum_exp = 0.0f;
    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
    // Position the table so that table_offset[x] == exp(-input_scale*beta*(max_val - x)),
    // given PopulateSoftmaxLookupTable's layout table[255 - d] = exp(-scale*d).
    // NOTE(review): this indexing assumes 0 <= input_data[j] <= max_val (i.e. an
    // unsigned 8-bit In); negative int8 values would index before table_offset's
    // valid range — confirm callers only instantiate this with In = uint8_t.
    const float *table_offset = &params.table[max_uint8 - max_val];
    // Calculate normalizer sum(exp(x)).
    for (int j = 0; j < last_dim; ++j)
    {
      sum_exp += table_offset[input_data[j]];
    }

    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
    // Normalize and quantize probabilities.
    for (int j = 0; j < last_dim; ++j)
    {
      const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
      const int32_t prob_quantized = QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
      output_data[j] = static_cast<Out>(std::max(std::min(clamp_max, prob_quantized), clamp_min));
    }
    // Advance to the next softmax group.
    input_data += last_dim;
    output_data += last_dim;
  }
}
198
199#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
// Looks up each element of <indices> in <table>, returns them in a vector.
// <table> is a 256-entry byte table split into four 64-byte quarters, each
// held in one uint8x16x4_t. vqtbl4q_u8 produces 0 for any lane whose index is
// out of the 0..63 range, so each of the four lookups contributes only the
// lanes whose top two bits select that quarter; XOR-ing with the quarter's
// base (0x40/0x80/0xc0) maps those lanes into range. OR-ing the four partial
// results therefore merges them without collisions.
inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4], uint8x16_t indices)
{
  // Look up in 1st quarter of the table: top 2 bits of indices == 00
  uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
  // Look up in 2nd quarter of the table: top 2 bits of indices == 01
  uint8x16_t output2 = vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
  // Look up in 3rd quarter of the table: top 2 bits of indices == 10
  uint8x16_t output3 = vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
  // Look up in 4th quarter of the table: top 2 bits of indices == 11
  uint8x16_t output4 = vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));

  // Combine result of the 4 lookups.
  return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
}
215
// Builds the two byte tables used by SoftmaxInt8LUT: for each quantized input
// value v (0..255), exp(input_scale * beta * (v - 255)) is encoded as a
// 16-bit fixed-point number (max 0xffff == 1.0); uint8_table1 receives the
// high byte and uint8_table2 the low byte.
inline void PopulateSoftmaxUInt8LookupTable(uint8_t *uint8_table1, uint8_t *uint8_table2,
                                            float input_scale, float beta)
{
  const float scale = input_scale * beta;
  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
  const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();

  for (int32_t entry = 0; entry <= max_uint8; ++entry)
  {
    // Argument is <= 0, so exp() is in (0, 1]; entry == 255 maps to exp(0).
    const float arg = scale * (entry - max_uint8);
    // Round to 16-bit fixed point, then saturate at max_uint16.
    int32_t fixed = static_cast<int>(expf(arg) * max_uint16 + 0.5);
    if (fixed > max_uint16)
    {
      fixed = max_uint16;
    }
    // Split into high / low bytes across the two tables.
    uint8_table1[entry] = static_cast<uint8_t>(fixed >> 8);
    uint8_table2[entry] = static_cast<uint8_t>(fixed & 0xff);
  }
}
234
// Returns the maximum of input_data[0..size), where every byte is first
// XOR-ed with `offset` before comparison. Per SoftmaxInt8LUT below, offset is
// 0 for uint8 data (no-op) and 0x80 for int8 data reinterpreted as uint8
// (e.g. int8 127 becomes 255), so the unsigned max matches the signed max.
// The main loop handles 16 bytes per iteration with NEON; a scalar loop
// processes the remaining tail.
inline int FindMaxValue(int size, const uint8_t *input_data, uint8_t offset)
{
  int32_t max_val = std::numeric_limits<uint8_t>::min();
  int j = 0;

  uint8x16_t max_val_dup = vdupq_n_u8(max_val);
  uint8x16_t offset_dup = vdupq_n_u8(offset);
  for (; j <= size - 16; j += 16)
  {
    uint8x16_t input_value = vld1q_u8(input_data + j);
    input_value = veorq_u8(input_value, offset_dup);
    max_val_dup = vmaxq_u8(input_value, max_val_dup);
  }
  // Horizontal max across the 16 lanes.
  max_val = std::max(max_val, static_cast<int32_t>(vmaxvq_u8(max_val_dup)));

  // Scalar tail for sizes that are not a multiple of 16.
  for (; j < size; ++j)
  {
    max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
  }
  return max_val;
}
256
257#ifdef USE_NEON
// Value_to_store layout:
// [high_high, high_low, low_high, low_low].
// Saturating-narrows 16 int32 lanes down to 16 int8 values and stores them at
// `output`: vqmovn_s32 narrows 32->16 bits, vqmovn_s16 narrows 16->8 bits,
// both clamping out-of-range values. The vcombine order restores the original
// element order given the layout above.
inline void StoreValue(int32x4x4_t value_to_store, int8_t *output)
{
  const int16x8_t result_1 =
    vcombine_s16(vqmovn_s32(value_to_store.val[1]), vqmovn_s32(value_to_store.val[0]));
  const int16x8_t result_2 =
    vcombine_s16(vqmovn_s32(value_to_store.val[3]), vqmovn_s32(value_to_store.val[2]));
  const int8x16_t result = vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
  vst1q_s8(output, result);
}
269
// Value_to_store layout:
// [high_high, high_low, low_high, low_low].
// uint8 variant of StoreValue: reinterprets the int32 lanes as uint32, then
// saturating-narrows 32->16->8 bits (vqmovn_u32 / vqmovn_u16 clamp values
// above the unsigned range) and stores 16 bytes at `output`.
inline void StoreValue(int32x4x4_t value_to_store, uint8_t *output)
{
  const uint16x8_t result_1 =
    vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
                 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
  const uint16x8_t result_2 =
    vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
                 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
  const uint8x16_t result = vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
  vst1q_u8(output, result);
}
283
284#endif
285
// Quantized softmax over the last dimension using the two byte tables
// prepared by PopulateSoftmaxUInt8LookupTable: uint8_table1 holds the high
// byte and uint8_table2 the low byte of exp() in 16-bit fixed point. AArch64
// only — relies on vqtbl4q-based vector table lookups (aarch64_lookup_vector)
// for 16 elements at a time, with scalar fallbacks for the tail.
template <typename In, typename Out>
inline void SoftmaxInt8LUT(const SoftmaxParams &params, const Shape &input_shape,
                           const In *input_data, const Shape &output_shape, Out *output_data)
{
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int excluding_last_dim = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int last_dim = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  const int32_t clamp_max = std::numeric_limits<Out>::max();
  const int32_t clamp_min = std::numeric_limits<Out>::min();

  // Offset is used to interpret the input data "correctly".
  // If the input is uint8, the data will be unchanged.
  // If the input is int8, since it will be reinterpret as uint8.
  // e.g.,
  // int8 127 will be applied "offset" to become 255 in uint8.
  uint8_t offset = 0;
  if (std::is_same<In, int8_t>::value)
  {
    offset = 0x80;
  }

  const uint8_t *input_data_uint = reinterpret_cast<const uint8_t *>(input_data);

  // This code uses ARM64-only instructions.
  // TODO(b/143709993): Port to ARMv7

  // Load the tables into registers. (4*4 128-bit registers)
  uint8x16x4_t table1[4];
  table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
  table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
  table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
  table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);

  uint8x16x4_t table2[4];
  table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
  table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
  table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
  table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);

  for (int i = 0; i < excluding_last_dim; ++i)
  {
    // Find max quantized value.
    int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);

    int32_t sum_exp = 0;
    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
    // Bias added to each (offset-corrected) input so the row max lands on
    // table entry 255 (exp(0)); smaller inputs hit decaying entries.
    const uint8_t table_offset = max_uint8 - max_val;

    // Calculate normalizer sum(exp(x)).
    int sum_j = 0;
    uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
    uint8x16_t offset_dup = vdupq_n_u8(offset);
    uint32x4_t sum_4 = vdupq_n_u32(0);
    const int multiplier_shift = 8;
    for (; sum_j <= last_dim - 16; sum_j += 16)
    {
      uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
      input_value = veorq_u8(input_value, offset_dup);
      input_value = vaddq_u8(input_value, table_offset_dup);

      // Look up high (output1) and low (output2) exp() bytes, then recombine
      // them into 16-bit values: (high << 8) + low.
      const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
      const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);

      uint16x8_t exp_value1 = vshll_n_u8(vget_high_u8(output1), multiplier_shift);
      uint16x8_t exp_value2 = vshll_n_u8(vget_low_u8(output1), multiplier_shift);

      exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
      exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));

      // Pairwise-accumulate the 16-bit exp values into four 32-bit lanes.
      sum_4 = vpadalq_u16(sum_4, exp_value1);
      sum_4 = vpadalq_u16(sum_4, exp_value2);
    }
    // Horizontal sum of the four accumulator lanes.
    int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) + vgetq_lane_u32(sum_4, 2) +
               vgetq_lane_u32(sum_4, 3);
    sum_exp += temp;

    // Scalar tail of the normalizer sum.
    for (; sum_j < last_dim; ++sum_j)
    {
      const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;

      uint8_t part1 = params.uint8_table1[index];
      uint8_t part2 = params.uint8_table2[index];
      sum_exp += ((part1 << 8) + part2);
    }

    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);

    // Express 1/sum as a quantized multiplier so normalization stays in
    // integer arithmetic.
    int32_t multiplier, shift;
    QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);

    // Normalize and quantize probabilities.
    int j = 0;
    const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
    const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
    const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);

    for (; j <= last_dim - 16; j += 16)
    {
      // Same lookup/recombine sequence as in the sum loop above.
      uint8x16_t input_value = vld1q_u8(input_data_uint + j);
      input_value = veorq_u8(input_value, offset_dup);
      input_value = vaddq_u8(input_value, table_offset_dup);

      const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
      const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);

      uint16x8_t exp_value1 = vshll_n_u8(vget_high_u8(output1), multiplier_shift);
      uint16x8_t exp_value2 = vshll_n_u8(vget_low_u8(output1), multiplier_shift);

      exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
      exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));

      // Widen to 32-bit lanes for the quantized multiply.
      int32x4x4_t output_value;
      output_value.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
      output_value.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
      output_value.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
      output_value.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));

      int32x4x4_t temp_val = MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);

      // Add the output zero point, then clamp to Out's range.
      temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
      temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
      temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
      temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);

      temp_val.val[0] = vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
      temp_val.val[1] = vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
      temp_val.val[2] = vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
      temp_val.val[3] = vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);

      StoreValue(temp_val, output_data + j);
    }
    // Scalar tail of normalization.
    for (; j < last_dim; ++j)
    {
      const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
      const uint8_t part1 = params.uint8_table1[index];
      const uint8_t part2 = params.uint8_table2[index];
      const int32_t exp_value = (part1 << 8) + part2;
      const int32_t output_value = MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);

      output_data[j] = static_cast<Out>(
        std::max(std::min(clamp_max, output_value + params.zero_point), clamp_min));
    }
    // Advance to the next softmax group.
    input_data_uint += last_dim;
    output_data += last_dim;
  }
}
433#endif
434
435} // namespace cker
436} // namespace nnfw
437
438#endif // __NNFW_CKER_SOFTMAX_H__
int32_t DimensionsCount() const
Definition Shape.h:91
__global uchar * offset(const Image *img, int x, int y)
Definition helpers.h:540
const luci_interpreter::RuntimeShape output_shape
result
Definition infer.py:103
list input_data
Definition infer.py:29
Index shift(const Index &in_index, const Shape &shift_from)
Definition Common.cpp:26
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
void Softmax(const SoftmaxParams &params, const Shape &input_shape, const float *input_data, const Shape &output_shape, float *output_data)
Definition SoftMax.h:43
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
Definition Shape.h:220
int32_t QuantizeSoftmaxOutput< uint8_t >(float prob_rescaled, int32_t)
Definition SoftMax.h:142
MatrixMap< Scalar > MapAsMatrixWithLastDimAsRows(Scalar *data, const Shape &shape)
Definition Utils.h:60
int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point)
Definition SoftMax.h:134
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
Definition Shape.h:304
void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
Definition Utils.h:48
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
Definition Shape.h:297
void Softmax(const float *in, const int input_size, const int batch_size, const float beta, float *out)
Definition SoftMax.h:79
int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
Definition Utils.h:96
void PopulateSoftmaxLookupTable(float *table, float input_scale, float beta)
Definition SoftMax.h:148
Definition topk_v2.h:30
int32_t size[5]
Definition Slice.cpp:35
Definition Shape.h:28