17#include "kernels/InstanceNorm.h"
19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/common.h>
39 if (
input()->shape().num_dims() == 4)
44 gamma()->shape().dim(0) == 1);
48 beta()->shape().dim(0) == 1);
50 else if (
input()->shape().num_dims() == 3)
57 gamma()->shape().dim(0) == 1);
61 beta()->shape().dim(0) == 1);
71 switch (
input()->element_type())
73 case DataType::FLOAT32:
77 throw std::runtime_error(
"luci-intp InstanceNorm Unsupported type.");
81void InstanceNorm::evalFloat()
const
83 float activation_min, activation_max;
104 for (int32_t batch = 0; batch <
batches; batch++)
106 for (int32_t channel = 0; channel <
channels; channel++)
111 for (int32_t height = 0; height <
heights; height++)
113 for (int32_t width = 0; width <
widths; width++)
116 input_data[tflite::Offset(
input_shape, batch, height, width, channel)];
121 double mean = sum /
size;
126 double a =
gamma / (std::sqrt(var +
params().epsilon));
127 double b = -mean * a +
beta;
129 for (int32_t height = 0; height <
heights; height++)
131 for (int32_t width = 0; width <
widths; width++)
134 input_data[tflite::Offset(
output_shape, batch, height, width, channel)];
136 output_data[tflite::Offset(
output_shape, batch, height, width, channel)] =
137 tflite::ActivationFunctionWithMinMax((
float)
output_value, activation_min,
150 for (int32_t batch = 0; batch <
batches; batch++)
152 for (int32_t channel = 0; channel <
channels; channel++)
157 static_cast<size_t>(batch *
channels *
size) +
static_cast<size_t>(channel *
size);
158 for (int32_t
i = 0;
i <
size;
i++)
169 double a =
gamma / (std::sqrt(var +
params().epsilon));
170 double b = -mean * a +
beta;
172 for (int32_t
i = 0;
i <
size;
i++)
183 throw std::runtime_error(
"luci-intp InstanceNorm unsupported rank.");
const InstanceNormParams & params() const
void resize(const Shape &new_shape)
void execute() const override
const Tensor * beta() const
const Tensor * input() const
void configure() override
const Tensor * gamma() const
InstanceNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output, const InstanceNormParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
__global uchar * offset(const Image *img, int x, int y)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void calculateActivationRange(Activation activation, T *activation_min, T *activation_max)
T must_cast(loco::Node *node)