ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALLogSoftmax.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2021 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef ONERT_MICRO_EXECUTE_PAL_LOG_SOFTMAX_COMMON_H
19#define ONERT_MICRO_EXECUTE_PAL_LOG_SOFTMAX_COMMON_H
20
#include "core/OMKernelData.h"

#include <algorithm>
#include <cmath>
#include <limits>
24
25namespace onert_micro
26{
27namespace execute
28{
29namespace pal
30{
31
32inline OMStatus LogSoftmax(const core::LogSoftmaxParams &params, const float *input_data,
33 float *output_data)
34{
35 const int outer_size = params.num_rows;
36 const int depth = params.row_size;
37
38 for (int i = 0; i < outer_size; ++i)
39 {
40 // Find max element value which we'll use to ensure numerical stability
41 // taking advantage of the following equality:
42 // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
43 float max = std::numeric_limits<float>::lowest();
44 for (int c = 0; c < depth; ++c)
45 {
46 max = std::max(max, input_data[i * depth + c]);
47 }
48
49 // Compute sum.
50 float sum = 0.f;
51 for (int c = 0; c < depth; ++c)
52 {
53 sum += std::exp(input_data[i * depth + c] - max);
54 }
55
56 // Compute result.
57 const float log_sum = std::log(sum);
58 for (int c = 0; c < depth; ++c)
59 {
60 output_data[i * depth + c] = input_data[i * depth + c] - max - log_sum;
61 }
62 }
63
64 return Ok;
65}
66
67} // namespace pal
68} // namespace execute
69} // namespace onert_micro
70
71#endif // ONERT_MICRO_EXECUTE_PAL_LOG_SOFTMAX_COMMON_H
OMStatus LogSoftmax(const core::LogSoftmaxParams &params, const float *input_data, float *output_data)