ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SoftMax.h
Go to the documentation of this file.
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#ifndef __NNFW_CKER_TRAIN_SOFTMAX_H__
18#define __NNFW_CKER_TRAIN_SOFTMAX_H__
19
20#include "cker/Shape.h"
21#include "cker/eigen/Utils.h"
22
23namespace nnfw
24{
25namespace cker
26{
27namespace train
28{
29
30inline void SoftMaxGrad(const Shape &output_shape, const float *output_data,
31 const Shape &incoming_shape, const float *incoming_data,
32 const Shape &grad_shape, float *grad_data)
33{
34 // TODO Support 4dim softmax gradient
35 assert(incoming_shape.DimensionsCount() == 2);
36 MatchingFlatSize(output_shape, incoming_shape, grad_shape);
37
38 const int batches = incoming_shape.Dims(0);
39 const int width = incoming_shape.Dims(1);
40
41 for (int b = 0; b < batches; ++b)
42 {
43 int b_offset = b * width;
44 for (int w1 = 0; w1 < width; ++w1)
45 {
46 float sum = 0.0f;
47 for (int w2 = 0; w2 < width; ++w2)
48 {
49 float val;
50 if (w1 == w2)
51 {
52 val = output_data[b_offset + w2] * (1.f - output_data[b_offset + w2]);
53 }
54 else
55 {
56 val = -output_data[b_offset + w2] * output_data[b_offset + w1];
57 }
58 val *= incoming_data[b_offset + w2];
59 sum += val;
60 }
61 grad_data[b_offset + w1] = sum;
62 }
63 }
64}
65
66} // namespace train
67} // namespace cker
68} // namespace nnfw
69
70#endif // __NNFW_CKER_TRAIN_SOFTMAX_H__
int32_t DimensionsCount() const
Definition Shape.h:91
int32_t Dims(int i) const
Definition Shape.h:92
const luci_interpreter::RuntimeShape output_shape
void SoftMaxGrad(const Shape &output_shape, const float *output_data, const Shape &incoming_shape, const float *incoming_data, const Shape &grad_shape, float *grad_data)
Definition SoftMax.h:30
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
Definition Shape.h:297
Definition topk_v2.h:30