19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/common.h>
43 (
gamma()->shape().dim(0) == 1));
50 switch (
input()->element_type())
52 case DataType::FLOAT32:
56 throw std::runtime_error(
"luci-intp RmsNorm Unsupported type.");
60void RmsNorm::evalFloat()
const
78 for (int32_t batch = 0; batch <
batches; batch++)
80 for (int32_t height = 0; height <
heights; height++)
82 for (int32_t width = 0; width <
widths; width++)
85 for (int32_t channel = 0; channel <
channels; channel++)
88 input_data[tflite::Offset(
input_shape, batch, height, width, channel)];
92 for (int32_t channel = 0; channel <
channels; channel++)
95 output_data[tflite::Offset(
output_shape, batch, height, width, channel)] =
97 (input_data[tflite::Offset(
input_shape, batch, height, width, channel)] /
rms);
109 for (int32_t batch = 0; batch <
batches; batch++)
111 for (int32_t channel = 0; channel <
channels; channel++)
115 static_cast<size_t>(batch *
channels *
size) +
static_cast<size_t>(channel *
size);
116 for (int32_t
i = 0;
i <
size;
i++)
122 for (int32_t
i = 0;
i <
size;
i++)
131 throw std::runtime_error(
"luci-intp RmsNorm unsupported rank.");
const RmsNormParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
RmsNorm(const Tensor *input, const Tensor *gamma, Tensor *output, const RmsNormParams &params)
const Tensor * gamma() const
const Tensor * input() const
void execute() const override
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
__global uchar * offset(const Image *img, int x, int y)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)