17#include "kernels/InstanceNorm.h"
19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/common.h>
39 if (
input()->shape().num_dims() == 4)
44 gamma()->shape().dim(0) == 1);
48 beta()->shape().dim(0) == 1);
50 else if (
input()->shape().num_dims() == 3)
57 gamma()->shape().dim(0) == 1);
61 beta()->shape().dim(0) == 1);
71 switch (
input()->element_type())
73 case DataType::FLOAT32:
77 throw std::runtime_error(
"luci-intp InstanceNorm Unsupported type.");
81void InstanceNorm::evalFloat()
const
83 float activation_min, activation_max;
88 const float *input_data = getTensorData<float>(
input());
89 const float *gamma_data = getTensorData<float>(
gamma());
91 bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
92 const float *beta_data = getTensorData<float>(
beta());
94 bool single_beta = beta_shape.DimensionsCount() == 1 && beta_shape.Dims(0) == 1;
95 float *output_data = getTensorData<float>(
output());
97 if (input_shape.DimensionsCount() == 4)
100 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
101 const int32_t heights = tflite::MatchingDim(input_shape, 1,
output_shape, 1);
102 const int32_t widths = tflite::MatchingDim(input_shape, 2,
output_shape, 2);
103 const int32_t channels = tflite::MatchingDim(input_shape, 3,
output_shape, 3);
104 for (int32_t batch = 0; batch < batches; batch++)
106 for (int32_t channel = 0; channel < channels; channel++)
109 double square_sum = 0.0f;
110 int32_t
size = heights * widths;
111 for (int32_t height = 0; height < heights; height++)
113 for (int32_t width = 0; width < widths; width++)
116 input_data[tflite::Offset(input_shape, batch, height, width, channel)];
118 square_sum += (input_val * input_val);
121 double mean = sum /
size;
122 double var = square_sum /
size - mean * mean;
124 double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
125 double beta = single_beta ? beta_data[0] : beta_data[channel];
126 double a =
gamma / (std::sqrt(var +
params().epsilon));
127 double b = -mean * a +
beta;
129 for (int32_t height = 0; height < heights; height++)
131 for (int32_t width = 0; width < widths; width++)
134 input_data[tflite::Offset(
output_shape, batch, height, width, channel)];
135 double output_value = input_value * a + b;
136 output_data[tflite::Offset(
output_shape, batch, height, width, channel)] =
137 tflite::ActivationFunctionWithMinMax((
float)output_value, activation_min,
144 else if (input_shape.DimensionsCount() == 3)
147 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
148 const int32_t channels = tflite::MatchingDim(input_shape, 1,
output_shape, 1);
150 for (int32_t batch = 0; batch < batches; batch++)
152 for (int32_t channel = 0; channel < channels; channel++)
155 double square_sum = 0.0f;
157 static_cast<size_t>(batch * channels *
size) +
static_cast<size_t>(channel *
size);
158 for (int32_t i = 0; i <
size; i++)
162 square_sum += (input_val * input_val);
165 double var = square_sum /
size - mean * mean;
167 double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
168 double beta = single_beta ? beta_data[0] : beta_data[channel];
169 double a =
gamma / (std::sqrt(var +
params().epsilon));
170 double b = -mean * a +
beta;
172 for (int32_t i = 0; i <
size; i++)
175 double output_value = input_value * a +
b;
177 (
float)output_value, activation_min, activation_max);
183 throw std::runtime_error(
"luci-intp InstanceNorm unsupported rank.");
// NOTE(review): the lines below are bodiless declaration fragments and are
// extraction artifacts, not part of this .cpp file's logic. They look like
// signatures pulled from InstanceNorm.h / Utils.h (params, resize, execute,
// accessors, constructor, LUCI_INTERPRETER_CHECK, getTensorShape,
// calculateActivationRange) plus an unrelated OpenCL helper (`__global uchar *
// offset`). As written here they are not valid C++ at namespace scope in a
// source file — confirm against the original headers and drop or restore them
// before building.
const InstanceNormParams & params() const
void resize(const Shape &new_shape)
void execute() const override
const Tensor * beta() const
const Tensor * input() const
void configure() override
const Tensor * gamma() const
InstanceNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output, const InstanceNormParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
__global uchar * offset(const Image *img, int x, int y)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void calculateActivationRange(Activation activation, T *activation_min, T *activation_max)