ONE - On-device Neural Engine
Loading...
Searching...
No Matches
RmsNorm.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __NNFW_CKER_RMS_NORM_H__
18#define __NNFW_CKER_RMS_NORM_H__
19
20#include "cker/Shape.h"
21#include "cker/Types.h"
22#include "cker/Utils.h"
23
24#include <cmath>
25#include <stdexcept>
26
27namespace nnfw
28{
29namespace cker
30{
31
32inline void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const float *input_data,
33 const Shape &gamma_shape, const float *gamma_data, const Shape &output_shape,
34 float *output_data)
35{
36 bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
37
38 if (input_shape.DimensionsCount() == 4)
39 {
40 const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
41 const int32_t heights = MatchingDim(input_shape, 1, output_shape, 1);
42 const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2);
43 const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3);
44
45 for (int32_t batch = 0; batch < batches; batch++)
46 {
47 for (int32_t height = 0; height < heights; height++)
48 {
49 for (int32_t width = 0; width < widths; width++)
50 {
51 // normalize over last-axis
52 double square_sum = 0.0f;
53 for (int32_t channel = 0; channel < channels; channel++)
54 {
55 double input_val = input_data[Offset(input_shape, batch, height, width, channel)];
56 square_sum += (input_val * input_val);
57 }
58 double rms = std::sqrt((square_sum / channels) + params.epsilon);
59 for (int32_t channel = 0; channel < channels; channel++)
60 {
61 double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
62 output_data[Offset(output_shape, batch, height, width, channel)] =
63 gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms);
64 }
65 }
66 }
67 }
68 }
69 else if (input_shape.DimensionsCount() == 3)
70 {
71 const int32_t heights = MatchingDim(input_shape, 0, output_shape, 0);
72 const int32_t widths = MatchingDim(input_shape, 1, output_shape, 1);
73 const int32_t channels = MatchingDim(input_shape, 2, output_shape, 2);
74
75 for (int32_t height = 0; height < heights; height++)
76 {
77 for (int32_t width = 0; width < widths; width++)
78 {
79 // normalize over last-axis
80 double square_sum = 0.0f;
81 for (int32_t channel = 0; channel < channels; channel++)
82 {
83 double input_val = input_data[(height * widths + width) * channels + channel];
84 square_sum += (input_val * input_val);
85 }
86 double rms = std::sqrt((square_sum / channels) + params.epsilon);
87 for (int32_t channel = 0; channel < channels; channel++)
88 {
89 double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
90 output_data[(height * widths + width) * channels + channel] =
91 gamma * (input_data[(height * widths + width) * channels + channel] / rms);
92 }
93 }
94 }
95 }
96 else
97 {
98 throw std::runtime_error("cker::RmsNorm: Unsupported input shape");
99 }
100}
101
102} // namespace cker
103} // namespace nnfw
104
105#endif // __NNFW_CKER_RMS_NORM_H__
int32_t DimensionsCount() const
Definition Shape.h:91
int32_t Dims(int i) const
Definition Shape.h:92
const luci_interpreter::RuntimeShape output_shape
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
Definition Shape.h:220
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:237
void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const float *input_data, const Shape &gamma_shape, const float *gamma_data, const Shape &output_shape, float *output_data)
Definition RmsNorm.h:32
Definition topk_v2.h:30