ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALSoftmaxCommon.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_SOFTMAX_COMMON_H
19#define LUCI_INTERPRETER_PAL_SOFTMAX_COMMON_H
20
21#include "Params.h"
22
24{
25inline void Softmax(const SoftmaxParams &params, const float *input_data, float *output_data)
26{
27 const int outer_size = params.num_rows;
28 const int depth = params.row_size;
29 const double beta = params.beta;
30
31 for (int i = 0; i < outer_size; ++i)
32 {
33 // Find max element value which we'll use to ensure numerical stability
34 // taking advantage of the following equality:
35 // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
36 float max = std::numeric_limits<float>::lowest();
37 for (int c = 0; c < depth; ++c)
38 {
39 max = std::max(max, input_data[i * depth + c]);
40 }
41
42 // Compute sum.
43 float sum = 0.f;
44 for (int c = 0; c < depth; ++c)
45 {
46 const float exp_c = std::exp((input_data[i * depth + c] - max) * static_cast<float>(beta));
47 output_data[i * depth + c] = exp_c;
48 sum += exp_c;
49 }
50
51 // Compute result.
52 for (int c = 0; c < depth; ++c)
53 {
54 output_data[i * depth + c] = output_data[i * depth + c] / sum;
55 }
56 }
57}
58
59} // namespace luci_interpreter_pal
60
61#endif // LUCI_INTERPRETER_PAL_SOFTMAX_COMMON_H