19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/common.h>
43 (
gamma()->shape().dim(0) == 1));
50 switch (
input()->element_type())
52 case DataType::FLOAT32:
56 throw std::runtime_error(
"luci-intp RmsNorm Unsupported type.");
60void RmsNorm::evalFloat()
const
65 const float *input_data = getTensorData<float>(
input());
66 const float *gamma_data = getTensorData<float>(
gamma());
68 bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
69 float *output_data = getTensorData<float>(
output());
71 if (input_shape.DimensionsCount() == 4)
74 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
75 const int32_t heights = tflite::MatchingDim(input_shape, 1,
output_shape, 1);
76 const int32_t widths = tflite::MatchingDim(input_shape, 2,
output_shape, 2);
77 const int32_t channels = tflite::MatchingDim(input_shape, 3,
output_shape, 3);
78 for (int32_t batch = 0; batch < batches; batch++)
80 for (int32_t height = 0; height < heights; height++)
82 for (int32_t width = 0; width < widths; width++)
84 double square_sum = 0.0f;
85 for (int32_t channel = 0; channel < channels; channel++)
88 input_data[tflite::Offset(input_shape, batch, height, width, channel)];
89 square_sum += (input_val * input_val);
91 double rms = std::sqrt((square_sum / channels) +
params().epsilon);
92 for (int32_t channel = 0; channel < channels; channel++)
94 double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
95 output_data[tflite::Offset(
output_shape, batch, height, width, channel)] =
97 (input_data[tflite::Offset(input_shape, batch, height, width, channel)] / rms);
103 else if (input_shape.DimensionsCount() == 3)
106 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
107 const int32_t channels = tflite::MatchingDim(input_shape, 1,
output_shape, 1);
109 for (int32_t batch = 0; batch < batches; batch++)
111 for (int32_t channel = 0; channel < channels; channel++)
113 double square_sum = 0.0f;
115 static_cast<size_t>(batch * channels *
size) +
static_cast<size_t>(channel *
size);
116 for (int32_t i = 0; i <
size; i++)
119 square_sum += (input_val * input_val);
121 double rms = std::sqrt((square_sum /
size) +
params().epsilon);
122 for (int32_t i = 0; i <
size; i++)
124 double gamma = single_gamma ? gamma_data[0] : gamma_data[i];
131 throw std::runtime_error(
"luci-intp RmsNorm unsupported rank.");
const RmsNormParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
RmsNorm(const Tensor *input, const Tensor *gamma, Tensor *output, const RmsNormParams ¶ms)
const Tensor * gamma() const
const Tensor * input() const
void execute() const override
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
__global uchar * offset(const Image *img, int x, int y)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)