ONE - On-device Neural Engine
PALReduceCommon.h
/*
 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ONERT_MICRO_PAL_REDUCE_COMMON_H
#define ONERT_MICRO_PAL_REDUCE_COMMON_H

#include <limits>

#include "PALUtils.h"
#include "core/OMRuntimeShape.h"

using namespace onert_micro::core;
namespace onert_micro
{
namespace execute
{
namespace pal
{

// clang-format off

// ------------------------------------------------------------------------------------------------
template <class T>
struct ReduceSumFn
{
  constexpr static T InitValue = T(0);

  T operator()(const T current, const T in)
  {
    return in + current;
  }
};

template <>
struct ReduceSumFn<int8_t>
{
  constexpr static float InitValue = 0.f;

  float operator()(const float current, const float in)
  {
    return in + current;
  }
};

// ------------------------------------------------------------------------------------------------

template <typename T>
struct ReduceProductFn
{
  constexpr static T InitValue = T(1);

  T operator()(const T current, const T in)
  {
    return in * current;
  }
};

template <>
struct ReduceProductFn<int8_t>
{
  constexpr static float InitValue = 1.f;

  float operator()(const float current, const float in)
  {
    return in * current;
  }
};
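
// Example (illustrative sketch): each functor pairs an initial accumulator value with a
// binary combine step, e.g.
//
//   ReduceSumFn<float> sum_fn;
//   float acc = ReduceSumFn<float>::InitValue; // 0.f
//   acc = sum_fn(acc, 2.f);                    // 2.f
//   acc = sum_fn(acc, 3.f);                    // 5.f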

// ------------------------------------------------------------------------------------------------

// clang-format on

// This method parses the input 'axis' to remove duplicates and handle negative
// values, and returns a valid 'out_axis'.
inline bool resolveAxis(const int num_dims, const int *axis, const int64_t num_axis, int *out_axis,
                        int *out_num_axis)
{
  *out_num_axis = 0; // Just in case.
  // Short-circuit axis resolution for scalars; the axis will go unused.
  if (num_dims == 0)
  {
    return true;
  }
  // O(n^2) is fine since out_num_axis should be really small, typically <= 4.
  for (int64_t idx = 0; idx < num_axis; ++idx)
  {
    // Handle negative index. A positive index 'p_idx' can be represented as a
    // negative index 'n_idx' as: n_idx = p_idx - num_dims.
    // E.g. for num_dims=3, [0, 1, 2] is the same as [-3, -2, -1].
    int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
    if (current < 0 || current >= num_dims)
    {
      return false;
    }
    bool is_dup = false;
    for (int j = 0; j < *out_num_axis; ++j)
    {
      if (out_axis[j] == current)
      {
        is_dup = true;
        break;
      }
    }
    if (!is_dup)
    {
      // Callers pass buffers sized for at most two resolved axes (resolved_axis[2]),
      // so bail out if a third distinct axis shows up.
      if (*out_num_axis > 1)
      {
        return false;
      }
      out_axis[*out_num_axis] = current;
      *out_num_axis += 1;
    }
  }
  return true;
}
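
// Example (illustrative sketch): negative axes are wrapped and duplicates dropped,
// so for a rank-3 input:
//
//   int axis[] = {-1, 2, 0};
//   int resolved[2];
//   int resolved_count = 0;
//   bool ok = resolveAxis(3, axis, 3, resolved, &resolved_count);
//   // ok == true, resolved == {2, 0}, resolved_count == 2
//   // (-1 wraps to 2; the explicit 2 is then a duplicate)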

// ------------------------------------------------------------------------------------------------

// Old version (used in Sum and ReduceProd), to be replaced with the new one (see below).
// Computes the generic value (i.e., sum/max/min/prod) of elements across
// dimensions given in axis. It needs to pass in init_value and reducer.
template <typename T>
inline bool ReduceGeneric(const T *input_data, const int *input_dims, const int input_num_dims,
                          T *output_data, const int *axis, const int64_t num_axis_dimensions,
                          T init_value, const int output_flat_size, T reducer(const T, const T))
{
  // Return early when input shape has zero dim.
  for (int i = 0; i < input_num_dims; ++i)
  {
    if (input_dims[i] == 0)
      return false;
  }

  for (size_t idx = 0; idx < output_flat_size; ++idx)
  {
    output_data[idx] = init_value;
  }

  // Resolve axis.
  int num_resolved_axis = 0;
  int resolved_axis[2];

  if (!resolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, &num_resolved_axis))
  {
    return false;
  }

  int temp_index[5];
  // Reset input iterator.
  for (int idx = 0; idx < input_num_dims; ++idx)
  {
    temp_index[idx] = 0;
  }
  // Iterate through input_data.
  do
  {
    size_t input_offset = reducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr);
    size_t output_offset =
      reducedOutputOffset(input_num_dims, input_dims, temp_index, num_resolved_axis, axis);
    output_data[output_offset] = reducer(output_data[output_offset], input_data[input_offset]);
  } while (nextIndex(input_num_dims, input_dims, temp_index));

  return true;
}

// This method expects that output_data has been initialized.
template <typename T>
inline bool reduceSumImpl(const T *input_data, const int *input_dims, const int input_num_dims,
                          T *output_data, const int *axis, const int num_axis,
                          const int num_outputs)
{
  return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data, axis, num_axis,
                          static_cast<T>(0), num_outputs,
                          [](const T current, const T in) -> T { return in + current; });
}
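
// Example (illustrative sketch): summing a 2x3 tensor over axis 1 into a 2-element output.
//
//   const float input[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
//   const int input_dims[] = {2, 3};
//   const int axis[] = {1};
//   float output[2];
//   reduceSumImpl<float>(input, input_dims, 2, output, axis, 1, 2);
//   // output == {6.f, 15.f}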

// Mean over H and W (axes 1 and 2).
template <typename T>
inline void MeanROWH(const OMRuntimeShape &unextended_input_shape, const T *input_data,
                     const OMRuntimeShape &unextended_output_shape, T *output_data)
{
  // The current implementation only supports rank-4 tensors and simultaneous
  // reduction over width and height.
  const OMRuntimeShape input_shape = OMRuntimeShape::extendedShape(4, unextended_input_shape);
  const OMRuntimeShape output_shape = OMRuntimeShape::extendedShape(4, unextended_output_shape);

  const int output_batch = output_shape.dims(0);
  const int output_depth = output_shape.dims(3);

  const int input_height = input_shape.dims(1);
  const int input_width = input_shape.dims(2);

  for (int out_b = 0; out_b < output_batch; ++out_b)
  {
    for (int out_d = 0; out_d < output_depth; ++out_d)
    {
      float value = 0;
      for (int in_h = 0; in_h < input_height; ++in_h)
      {
        for (int in_w = 0; in_w < input_width; ++in_w)
        {
          value += static_cast<float>(
            input_data[offset(input_shape.dimsData(), out_b, in_h, in_w, out_d)]);
        }
      }
      float result = value / (input_width * input_height);
      output_data[offset(output_shape.dimsData(), out_b, 0, 0, out_d)] = static_cast<T>(result);
    }
  }
}
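
// Example: for an NHWC input of shape (1, 2, 2, 1) holding {1, 2, 3, 4}, the mean over
// H and W collapses the spatial dimensions to a (1, 1, 1, 1) output holding
// (1 + 2 + 3 + 4) / 4 = 2.5.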

// New version (used in Mean).
template <typename T, template <typename> class ReduceFn>
inline bool ReduceGeneric(OMReduceDataContext<T> &ctx)
{
  auto &input = ctx.Input();
  auto &input_data = input.Data();
  auto input_dims = input.Dims();
  auto input_num_dims = input.DimsCount();

  auto &output = ctx.Output();
  auto &output_data = output.Data();
  auto output_flat_size = output.ShapeFlatSize();

  auto &axis_ctx = ctx.Axis();
  auto &axis = axis_ctx.Data();
  auto num_axis_dimensions = axis_ctx.DimsCount();

  // Return early when input shape has zero dim.
  for (size_t i = 0; i < input_num_dims; ++i)
  {
    if (input_dims[i] == 0)
      return false;
  }

  for (size_t idx = 0; idx < output_flat_size; ++idx)
  {
    output_data.SetValueAt(idx, ReduceFn<T>::InitValue);
  }

  // Resolve axis.
  int num_resolved_axis = 0;
  int resolved_axis[2];

  if (!resolveAxis(input_num_dims, axis.Get(), num_axis_dimensions, resolved_axis,
                   &num_resolved_axis))
  {
    return false;
  }

  int temp_index[5];
  // Reset input iterator.
  for (size_t idx = 0; idx < input_num_dims; ++idx)
  {
    temp_index[idx] = 0;
  }

  // Iterate through input_data.
  do
  {
    size_t input_offset = reducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr);
    size_t output_offset =
      reducedOutputOffset(input_num_dims, input_dims, temp_index, num_resolved_axis, axis.Get());

    ReduceFn<T> reducer;
    auto value = reducer(output_data.ValueAt(output_offset), input_data.ValueAt(input_offset));
    output_data.SetValueAt(output_offset, value);

  } while (nextIndex(input_num_dims, input_dims, temp_index));

  return true;
}

// ------------------------------------------------------------------------------------------------

template <typename T, template <typename> class ReduceFn>
bool Reduce(OMReduceDataContext<T> &ctx, bool mean = false)
{
  // A special-case mean implementation exists for a 4D mean across axes 1 and 2.
  const int *axis_value = ctx.Axis().Data().Get();
  bool special_case_4d_axes_1_and_2 =
    ctx.Input().DimsCount() == 4 && ctx.Axis().ShapeFlatSize() == 2 &&
    ((axis_value[0] == 1 && axis_value[1] == 2) || (axis_value[0] == 2 && axis_value[1] == 1));
  if (special_case_4d_axes_1_and_2)
  {
    OMRuntimeShape input_shape(ctx.Input().Shape());
    OMRuntimeShape output_shape(ctx.Output().Shape());
    MeanROWH<T>(input_shape, ctx.Input().Data().Get(), output_shape, ctx.Output().Data().Get());
    return true;
  }

  constexpr static T kInitValue = T(0);

  if (!ReduceGeneric<T, ReduceFn>(ctx))
  {
    return false;
  }

  auto &input = ctx.Input();
  auto input_dims = input.Dims();
  auto input_num_dims = input.DimsCount();

  auto &output = ctx.Output();
  auto &output_data = output.Data();
  auto num_outputs = output.ShapeFlatSize();

  auto &axis = ctx.Axis().Data();
  auto num_axis_dimensions = ctx.Axis().DimsCount();

  // Resolve axis again for computing the mean.
  int num_resolved_axis = 0;
  int resolved_axis[2];

  if (!resolveAxis(input_num_dims, axis.Get(), num_axis_dimensions, resolved_axis,
                   &num_resolved_axis))
  {
    return false;
  }

  // clang-format off

  auto fnReduceOutput = [&](size_t divide_by = 1)
  {
    for (size_t idx = 0; idx < num_outputs; ++idx)
    {
      auto value = output_data.ValueAt(idx);
      value /= static_cast<T>(divide_by);
      output_data.SetAt(idx, value);
    }
  };

  // clang-format on

  if (!mean)
  {
    fnReduceOutput();
    return true;
  }

  // Calculate the mean by dividing output_data by the number of aggregated elements.
  size_t num_elements_in_axis = 1;
  for (int idx = 0; idx < num_resolved_axis; ++idx)
  {
    size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
    // Overflow prevention.
    if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis))
    {
      return false;
    }
    num_elements_in_axis *= current;
  }

  if (num_elements_in_axis > 0)
  {
    fnReduceOutput(num_elements_in_axis);
  }

  return true;
}
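
// Example (illustrative sketch; assumes 'ctx' is an OMReduceDataContext<float> already
// populated by the calling kernel):
//
//   Reduce<float, ReduceSumFn>(ctx);                 // Sum
//   Reduce<float, ReduceSumFn>(ctx, /*mean=*/true);  // Mean
//   Reduce<float, ReduceProductFn>(ctx);             // ReduceProd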

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_PAL_REDUCE_COMMON_H