ONE - On-device Neural Engine
Loading...
Searching...
No Matches
InstanceNorm.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "kernels/InstanceNorm.h"
18
19#include "kernels/Utils.h"
20
21#include <tensorflow/lite/kernels/internal/common.h>
22#include <cmath>
23
24namespace luci_interpreter
25{
26namespace kernels
27{
28
// Constructs the InstanceNorm kernel.
// Inputs: input activations, per-channel (or scalar) gamma scale and beta shift.
// Output: tensor of the same shape/type as input (resized in configure()).
// params: holds epsilon and the fused activation; stored by KernelWithParams.
InstanceNorm::InstanceNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta,
                           Tensor *output, const InstanceNormParams &params)
  : KernelWithParams<InstanceNormParams>({input, gamma, beta}, {output}, params)
{
}
34
36{
37 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
38 LUCI_INTERPRETER_CHECK(gamma()->element_type() == input()->element_type());
39 if (input()->shape().num_dims() == 4)
40 {
41 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() == 4);
42 LUCI_INTERPRETER_CHECK(gamma()->shape().num_dims() == 1);
43 LUCI_INTERPRETER_CHECK(gamma()->shape().dim(0) == input()->shape().dim(3) ||
44 gamma()->shape().dim(0) == 1);
45 LUCI_INTERPRETER_CHECK(beta()->element_type() == input()->element_type());
46 LUCI_INTERPRETER_CHECK(beta()->shape().num_dims() == 1);
47 LUCI_INTERPRETER_CHECK(beta()->shape().dim(0) == input()->shape().dim(3) ||
48 beta()->shape().dim(0) == 1);
49 }
50 else if (input()->shape().num_dims() == 3)
51 {
52 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() == 3);
53 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
54 LUCI_INTERPRETER_CHECK(gamma()->element_type() == input()->element_type());
55 LUCI_INTERPRETER_CHECK(gamma()->shape().num_dims() == 1);
56 LUCI_INTERPRETER_CHECK(gamma()->shape().dim(0) == input()->shape().dim(1) ||
57 gamma()->shape().dim(0) == 1);
58 LUCI_INTERPRETER_CHECK(beta()->element_type() == input()->element_type());
59 LUCI_INTERPRETER_CHECK(beta()->shape().num_dims() == 1);
60 LUCI_INTERPRETER_CHECK(beta()->shape().dim(0) == input()->shape().dim(1) ||
61 beta()->shape().dim(0) == 1);
62 }
63 else
64 LUCI_INTERPRETER_CHECK(false && "luci-intp InstanceNorm unsupported rank.");
65
66 output()->resize(input()->shape());
67}
68
70{
71 switch (input()->element_type())
72 {
73 case DataType::FLOAT32:
74 evalFloat();
75 break;
76 default:
77 throw std::runtime_error("luci-intp InstanceNorm Unsupported type.");
78 }
79}
80
81void InstanceNorm::evalFloat() const
82{
83 float activation_min, activation_max;
84 calculateActivationRange(params().activation, &activation_min, &activation_max);
85 tflite::RuntimeShape input_shape = getTensorShape(input());
87
88 const float *input_data = getTensorData<float>(input());
89 const float *gamma_data = getTensorData<float>(gamma());
90 auto gamma_shape = getTensorShape(gamma());
91 bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
92 const float *beta_data = getTensorData<float>(beta());
93 auto beta_shape = getTensorShape(beta());
94 bool single_beta = beta_shape.DimensionsCount() == 1 && beta_shape.Dims(0) == 1;
95 float *output_data = getTensorData<float>(output());
96
97 if (input_shape.DimensionsCount() == 4)
98 {
99 // Dimensions for image case are (N x H x W x C)
100 const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
101 const int32_t heights = tflite::MatchingDim(input_shape, 1, output_shape, 1);
102 const int32_t widths = tflite::MatchingDim(input_shape, 2, output_shape, 2);
103 const int32_t channels = tflite::MatchingDim(input_shape, 3, output_shape, 3);
104 for (int32_t batch = 0; batch < batches; batch++)
105 {
106 for (int32_t channel = 0; channel < channels; channel++)
107 {
108 double sum = 0.0f;
109 double square_sum = 0.0f;
110 int32_t size = heights * widths;
111 for (int32_t height = 0; height < heights; height++)
112 {
113 for (int32_t width = 0; width < widths; width++)
114 {
115 double input_val =
116 input_data[tflite::Offset(input_shape, batch, height, width, channel)];
117 sum += input_val;
118 square_sum += (input_val * input_val);
119 }
120 }
121 double mean = sum / size;
122 double var = square_sum / size - mean * mean;
123
124 double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
125 double beta = single_beta ? beta_data[0] : beta_data[channel];
126 double a = gamma / (std::sqrt(var + params().epsilon));
127 double b = -mean * a + beta;
128
129 for (int32_t height = 0; height < heights; height++)
130 {
131 for (int32_t width = 0; width < widths; width++)
132 {
133 double input_value =
134 input_data[tflite::Offset(output_shape, batch, height, width, channel)];
135 double output_value = input_value * a + b;
136 output_data[tflite::Offset(output_shape, batch, height, width, channel)] =
137 tflite::ActivationFunctionWithMinMax((float)output_value, activation_min,
138 activation_max);
139 }
140 }
141 }
142 }
143 }
144 else if (input_shape.DimensionsCount() == 3)
145 {
146 // Dimensions for non image case are (N x C x D1 x D2 … Dn)
147 const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
148 const int32_t channels = tflite::MatchingDim(input_shape, 1, output_shape, 1);
149 const int32_t size = tflite::MatchingDim(input_shape, 2, output_shape, 2);
150 for (int32_t batch = 0; batch < batches; batch++)
151 {
152 for (int32_t channel = 0; channel < channels; channel++)
153 {
154 double sum = 0.0f;
155 double square_sum = 0.0f;
156 size_t offset =
157 static_cast<size_t>(batch * channels * size) + static_cast<size_t>(channel * size);
158 for (int32_t i = 0; i < size; i++)
159 {
160 double input_val = input_data[offset + i];
161 sum += input_val;
162 square_sum += (input_val * input_val);
163 }
164 double mean = sum / size;
165 double var = square_sum / size - mean * mean;
166
167 double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
168 double beta = single_beta ? beta_data[0] : beta_data[channel];
169 double a = gamma / (std::sqrt(var + params().epsilon));
170 double b = -mean * a + beta;
171
172 for (int32_t i = 0; i < size; i++)
173 {
174 double input_value = input_data[offset + i];
175 double output_value = input_value * a + b;
176 output_data[offset + i] = tflite::ActivationFunctionWithMinMax(
177 (float)output_value, activation_min, activation_max);
178 }
179 }
180 }
181 }
182 else
183 throw std::runtime_error("luci-intp InstanceNorm unsupported rank.");
184}
185
186} // namespace kernels
187} // namespace luci_interpreter
const InstanceNormParams & params() const
Definition Kernel.h:67
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
InstanceNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output, const InstanceNormParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
__global uchar * offset(const Image *img, int x, int y)
Definition helpers.h:540
const luci_interpreter::RuntimeShape output_shape
list input_data
Definition infer.py:29
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void calculateActivationRange(Activation activation, T *activation_min, T *activation_max)
Definition Utils.cpp:52
int32_t size[5]
Definition Slice.cpp:35