17#include "kernels/InstanceNorm.h"
19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/common.h>
// InstanceNorm constructor (only the tail of the signature and the member
// initializer are visible in this view). Registers {input, gamma, beta} as the
// kernel's inputs and {output} as its single output, and forwards the kernel
// parameters to KernelWithParams<InstanceNormParams>.
// NOTE(review): `¶ms` appears to be a mis-encoded reference parameter
// (ampersand followed by `params` decoded as the pilcrow character). The
// initializer below passes `params`, so the declarator must read `& params`
// for this to compile. The fused numeric prefixes (`30 `, `31 `) are not valid
// C++ either and look like extraction artifacts; restore the original lines.
30 Tensor *output,
const InstanceNormParams ¶ms)
31 : KernelWithParams<InstanceNormParams>({
input, gamma, beta}, {
output}, params)
// configure(): appears to validate the kernel's tensors before execution.
// Only two continuation lines of its body are visible in this view. Each one
// looks like the final alternative of a larger check expression whose opening
// is not visible (presumably a LUCI_INTERPRETER_CHECK(...) call; confirm
// against the full file). Both accept a gamma/beta whose first dimension
// equals 1.
// evalFloat() treats a length-1 gamma/beta as a single value shared by all
// channels.
// NOTE(review): the numeric prefixes (`35`, `42 `, `46 `) fused onto these
// lines are not valid C++ and look like extraction artifacts.
35void InstanceNorm::configure()
42 gamma()->shape().dim(0) == 1);
46 beta()->shape().dim(0) == 1);
// execute(): dispatches on the input tensor's element type.
// The only case label visible here is DataType::FLOAT32. Its body is not
// visible in this view; presumably it calls evalFloat(). Other element types
// presumably reach the assert below.
// NOTE(review): the numeric prefixes (`51`, `53 `, `55 `, `59 `) fused onto
// these lines are not valid C++ and look like extraction artifacts.
51void InstanceNorm::execute()
const
53 switch (
input()->element_type())
55 case DataType::FLOAT32:
// NOTE(review): assert() expands to nothing when NDEBUG is defined. In
// release builds an unsupported element type would therefore not be reported,
// and the output tensor would presumably be left unwritten. Consider reporting
// the error unconditionally (e.g. throw std::runtime_error or return an error
// status, following the project's convention).
59 assert(
false &&
"Unsupported type.");
// evalFloat(): float32 instance normalization.
// Dimensions 0..3 of input/output are matched and used as batch, height,
// width and channel, as shown by the MatchingDim calls and the
// tflite::Offset(shape, batch, height, width, channel) indexing.
// For every (batch, channel) pair, statistics are taken over all
// heights * widths positions. Each element x is then mapped to
//   x * a + b,  with  a = gamma / sqrt(var + epsilon),  b = -mean * a + beta,
// which equals gamma * (x - mean) / sqrt(var + epsilon) + beta.
// The result is narrowed to float and passed to
// tflite::ActivationFunctionWithMinMax together with activation_min and
// activation_max.
// NOTE(review): this view is fragmentary. The following are not visible:
// - the declarations of input_shape, output_shape, gamma_shape, beta_shape,
//   input_data and input_value;
// - the computation of `mean`;
// - the output store;
// - the closing braces.
// Numeric prefixes fused onto lines (e.g. `65 `, `86 `) are not valid C++ and
// look like extraction artifacts.
63void InstanceNorm::evalFloat()
const
65 float activation_min, activation_max;
// Presumably MatchingDim checks that each dimension agrees between input and
// output, then returns its extent.
69 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
70 const int32_t heights = tflite::MatchingDim(input_shape, 1,
output_shape, 1);
71 const int32_t widths = tflite::MatchingDim(input_shape, 2,
output_shape, 2);
72 const int32_t channels = tflite::MatchingDim(input_shape, 3,
output_shape, 3);
// gamma and beta are either a single value shared by all channels (a 1-D
// tensor of length 1) or indexed per channel.
74 const float *gamma_data = getTensorData<float>(gamma());
76 bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
77 const float *beta_data = getTensorData<float>(beta());
79 bool single_beta = beta_shape.DimensionsCount() == 1 && beta_shape.Dims(0) == 1;
// One independent normalization per (batch, channel) slice.
81 for (int32_t batch = 0; batch < batches; batch++)
83 for (int32_t channel = 0; channel < channels; channel++)
// Statistics are accumulated in double, presumably to limit rounding error
// over heights * widths terms. 0.0f is a float literal, but zero is exact, so
// initializing a double with it is harmless.
86 double square_sum = 0.0f;
87 int32_t
size = heights * widths;
// First pass: accumulate the sum of squares for the variance. The plain sum
// used for `mean` is presumably accumulated on a line not visible here.
88 for (int32_t height = 0; height < heights; height++)
90 for (int32_t width = 0; width < widths; width++)
92 double input_val =
input_data[tflite::Offset(input_shape, batch, height, width, channel)];
94 square_sum += (input_val * input_val);
// One-pass (biased) variance: E[x^2] - mean^2.
// NOTE(review): this form can lose precision through cancellation when |mean|
// is large relative to the standard deviation. Rounding can also make var
// slightly negative; if var + epsilon < 0, std::sqrt below returns NaN.
// Consider clamping var at 0, or using a two-pass / Welford computation.
98 double var = square_sum /
size - mean * mean;
// Affine parameters for this channel: either the shared value or the
// per-channel entry.
100 double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
101 double beta = single_beta ? beta_data[0] : beta_data[channel];
// Fold the normalization and the affine transform into y = x * a + b.
102 double a = gamma / (std::sqrt(var + params().epsilon));
103 double b = -mean * a + beta;
// Second pass: apply the folded transform to every element of this
// (batch, channel) slice.
105 for (int32_t height = 0; height < heights; height++)
107 for (int32_t width = 0; width < widths; width++)
111 double output_value = input_value * a +
b;
// Narrow to float and pass through the activation bounds.
113 tflite::ActivationFunctionWithMinMax((
float)output_value, activation_min,
// NOTE(review): the lines below are bare signatures with no bodies and no
// terminating semicolons, so they are not valid declarations as written.
// They look like an appended symbol index rather than source code:
// - the first repeats the InstanceNorm constructor signature (with the same
//   `¶ms` mis-encoding of a `& params` reference parameter);
// - `output_shape` is a local used in evalFloat().
// Confirm, then remove them or restore the real code.
InstanceNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output, const InstanceNormParams ¶ms)
// As written, this #define gives LUCI_INTERPRETER_CHECK an empty expansion.
// That would turn every check using it into a no-op.
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void calculateActivationRange(Activation activation, T *activation_min, T *activation_max)