45 for (int32_t batch = 0; batch < batches; batch++)
47 for (int32_t height = 0; height < heights; height++)
49 for (int32_t width = 0; width < widths; width++)
52 double square_sum = 0.0f;
53 for (int32_t channel = 0; channel < channels; channel++)
55 double input_val = input_data[
Offset(input_shape, batch, height, width, channel)];
56 square_sum += (input_val * input_val);
58 double rms = std::sqrt((square_sum / channels) + params.
epsilon);
59 for (int32_t channel = 0; channel < channels; channel++)
61 double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
63 gamma * (input_data[
Offset(input_shape, batch, height, width, channel)] / rms);
75 for (int32_t height = 0; height < heights; height++)
77 for (int32_t width = 0; width < widths; width++)
80 double square_sum = 0.0f;
81 for (int32_t channel = 0; channel < channels; channel++)
83 double input_val = input_data[(height * widths + width) * channels + channel];
84 square_sum += (input_val * input_val);
86 double rms = std::sqrt((square_sum / channels) + params.
epsilon);
87 for (int32_t channel = 0; channel < channels; channel++)
89 double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
90 output_data[(height * widths + width) * channels + channel] =
91 gamma * (input_data[(height * widths + width) * channels + channel] / rms);
98 throw std::runtime_error(
"cker::RmsNorm: Unsupported input shape");
void RmsNorm(const RmsNormParams ¶ms, const Shape &input_shape, const float *input_data, const Shape &gamma_shape, const float *gamma_data, const Shape &output_shape, float *output_data)