18#ifndef ONERT_MICRO_PAL_REDUCE_COMMON_H
19#define ONERT_MICRO_PAL_REDUCE_COMMON_H
// resolveAxis (partial listing): walks the `num_axis` entries of `axis`, normalizes each one
// against `num_dims`, and stores it into `out_axis`, using `*out_num_axis` as the write position.
//
// NOTE(review): this text looks like a rendered source listing rather than the header itself —
// the original line numbers are fused onto the code ("90inline") and every line between the
// numbered ones (braces, `return` statements, the trailing `int *out_num_axis` parameter that
// the body dereferences) is absent. Comments below describe only what is visible; confirm the
// rest against the real header.
90inline bool resolveAxis(
const int num_dims,
const int *axis,
const int64_t num_axis,
int *out_axis,
// Visit every requested axis entry in order.
100 for (int64_t idx = 0; idx < num_axis; ++idx)
// A negative entry is shifted up by num_dims; non-negative entries are used unchanged.
105 int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
// Range guard: after normalization the entry must lie in [0, num_dims).
// The action taken when it does not is not shown — presumably the function fails; verify.
106 if (current < 0 || current >= num_dims)
// Linear scan over the axes stored so far, comparing each with the current entry.
// What happens on a match is not visible here — presumably the duplicate is skipped; verify.
111 for (
int j = 0; j < *out_num_axis; ++j)
113 if (out_axis[j] == current)
// Triggers when two axes are already stored. The callers visible in this file pass a
// 2-element `resolved_axis` buffer, so this is presumably the capacity check — TODO confirm.
121 if (*out_num_axis > 1)
// Store the entry at the current write position. The increment of *out_num_axis is not shown.
125 out_axis[*out_num_axis] = current;
// ReduceGeneric (raw-pointer overload, partial listing): folds `input_data` into `output_data`
// with the binary `reducer`, after seeding the output with `init_value`.
//
// Parameters visible in the signature:
//   input_data / input_dims / input_num_dims - input buffer, its dimension array and its rank.
//   output_data / output_flat_size           - output buffer and the number of elements seeded.
//   axis / num_axis_dimensions               - axis list forwarded to resolveAxis().
//   init_value                               - value written to every output element first.
//   reducer                                  - called as reducer(accumulated, input_element).
//
// NOTE(review): the `template <typename T>` line, all braces/returns, the declaration of
// `temp_index` and the right-hand side of `output_offset` are missing from this listing.
138inline bool ReduceGeneric(
const T *input_data,
const int *input_dims,
const int input_num_dims,
139 T *output_data,
const int *axis,
const int64_t num_axis_dimensions,
140 T init_value,
const int output_flat_size, T reducer(
const T,
const T))
// Look for a zero-sized input dimension. The outcome of the check is not shown.
143 for (
int i = 0; i < input_num_dims; ++i)
145 if (input_dims[i] == 0)
// Seed every output element with init_value.
// NOTE(review): `idx` is size_t while output_flat_size is a signed int — mixed-sign comparison.
149 for (
size_t idx = 0; idx < output_flat_size; ++idx)
151 output_data[idx] = init_value;
// Scratch space for the normalized axes; the buffer holds two entries.
155 int num_resolved_axis = 0;
156 int resolved_axis[2];
// Normalize the axis list. The branch taken on failure is not shown.
158 if (!
resolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, &num_resolved_axis))
// Per-dimension loop whose body is missing — presumably it resets temp_index; verify.
165 for (
int idx = 0; idx < input_num_dims; ++idx)
// Offset of the current index in the input: requested with 0 axes and a null axis list.
172 size_t input_offset =
reducedOutputOffset(input_num_dims, input_dims, temp_index, 0,
nullptr);
// Offset of the output slot; the initializer expression is not present in this listing.
173 size_t output_offset =
// Fold the current input element into its output slot.
175 output_data[output_offset] = reducer(output_data[output_offset], input_data[input_offset]);
// The loop repeats for as long as nextIndex() returns true (the opening `do` is not shown).
176 }
while (
nextIndex(input_num_dims, input_dims, temp_index));
// reduceSumImpl (partial listing): sum reduction expressed through ReduceGeneric<T>.
// All arguments are forwarded unchanged; the accumulator starts at static_cast<T>(0) and the
// reducer is a lambda returning `in + current`. The result of ReduceGeneric<T> is returned as is.
//
// NOTE(review): `num_axis` is a plain int here, while the ReduceGeneric parameter it feeds is
// declared as int64_t, so it is widened at the call. The template line and braces are missing
// from this listing.
183inline bool reduceSumImpl(
const T *input_data,
const int *input_dims,
const int input_num_dims,
184 T *output_data,
const int *axis,
const int num_axis,
185 const int num_outputs)
187 return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data, axis, num_axis,
// Initial accumulator value (zero of type T), followed by the number of output elements.
188 static_cast<T
>(0), num_outputs,
// Reducer: adds the incoming element to the running value.
189 [](
const T current,
const T in) -> T {
return in + current; });
// Body fragment of a mean over dimensions 1 and 2 of a 4-index input.
// NOTE(review): the function header is not part of this listing; the signature index at the end
// of the file suggests this is MeanROWH — confirm. The declarations of `output_batch`,
// `output_depth`, `value`, and whatever is done with `result`, are also not shown.
// Extents of the two dimensions being averaged.
205 const int input_height = input_shape.
dims(1);
206 const int input_width = input_shape.
dims(2);
// One result per (out_b, out_d) pair.
208 for (
int out_b = 0; out_b < output_batch; ++out_b)
210 for (
int out_d = 0; out_d < output_depth; ++out_d)
// Accumulate every (in_h, in_w) element for this pair, converted to float.
213 for (
int in_h = 0; in_h < input_height; ++in_h)
215 for (
int in_w = 0; in_w < input_width; ++in_w)
217 value +=
static_cast<float>(
218 input_data[
offset(input_shape.
dimsData(), out_b, in_h, in_w, out_d)]);
// Divide the accumulated sum by the number of averaged elements.
// NOTE(review): the divisor is an int product (input_width * input_height); no zero check is
// visible in this fragment.
221 float result = value / (input_width * input_height);
// ReduceGeneric (context overload, partial listing): same reduction loop as the raw-pointer
// overload, but reading input, output and axis through accessors on `ctx` and taking the
// initial value from ReduceFn<T>::InitValue.
//
// NOTE(review): the function signature, braces/returns, the declarations of `reducer` and
// `temp_index`, and the last argument of the resolveAxis() call are missing from this listing.
227template <
typename T,
template <
typename>
class ReduceFn>
// Input accessors: data handle, dimension array and rank.
230 auto &input = ctx.
Input();
231 auto &input_data = input.Data();
232 auto input_dims = input.Dims();
233 auto input_num_dims = input.DimsCount();
// Output accessors: data handle and flat element count.
235 auto &output = ctx.
Output();
236 auto &output_data = output.Data();
237 auto output_flat_size = output.ShapeFlatSize();
// Axis accessors.
// NOTE(review): DimsCount() of the axis tensor is later passed as the number of axis entries,
// whereas the 4D special-case check elsewhere in this file uses ShapeFlatSize() for that —
// confirm which one is the entry count.
239 auto &axis_ctx = ctx.Axis();
240 auto &axis = axis_ctx.Data();
241 auto num_axis_dimensions = axis_ctx.DimsCount();
// Look for a zero-sized input dimension. The outcome of the check is not shown.
244 for (
size_t i = 0; i < input_num_dims; ++i)
246 if (input_dims[i] == 0)
// Seed every output element with the reducer's initial value.
250 for (
size_t idx = 0; idx < output_flat_size; ++idx)
252 output_data.SetValueAt(idx, ReduceFn<T>::InitValue);
// Scratch space for the normalized axes; the buffer holds two entries.
256 int num_resolved_axis = 0;
257 int resolved_axis[2];
// Normalize the axis list. The remaining argument and the failure branch are not shown.
259 if (!
resolveAxis(input_num_dims, axis.Get(), num_axis_dimensions, resolved_axis,
// Per-dimension loop whose body is missing — presumably it resets temp_index; verify.
267 for (
size_t idx = 0; idx < input_num_dims; ++idx)
// Offset of the current index in the input: requested with 0 axes and a null axis list.
275 size_t input_offset =
reducedOutputOffset(input_num_dims, input_dims, temp_index, 0,
nullptr);
// Offset of the output slot for the current index.
// NOTE(review): the raw axis.Get() is passed together with num_resolved_axis, not the
// `resolved_axis` array filled above; negative or duplicated axis entries would then reach
// reducedOutputOffset() un-normalized — confirm this is intended.
276 size_t output_offset =
277 reducedOutputOffset(input_num_dims, input_dims, temp_index, num_resolved_axis, axis.Get());
// Fold the current input element into its output slot and write it back.
280 auto value = reducer(output_data.ValueAt(output_offset), input_data.ValueAt(input_offset));
281 output_data.SetValueAt(output_offset, value);
// The loop repeats for as long as nextIndex() returns true (the opening `do` is not shown).
283 }
while (
nextIndex(input_num_dims, input_dims, temp_index));
// Reduction entry point (partial listing): checks for a 4D / axes {1,2} special case, otherwise
// runs ReduceGeneric<T, ReduceFn>(ctx), and then contains code that divides every output element
// by the product of the reduced dimensions.
//
// NOTE(review): the signature line is missing; the signature index at the end of the file
// suggests this is `bool Reduce(OMReduceDataContext<T> &ctx, bool mean = false)` — confirm.
// Braces, returns, the body of the special-case branch and any condition guarding the
// division part are not present in this listing.
290template <
typename T,
template <
typename>
class ReduceFn>
// Special case: rank-4 input with exactly two axis values that are 1 and 2 in either order.
// The body of the branch is not shown.
295 const int *axis_value = ctx.Axis().Data().Get();
296 bool special_case_4d_axes_1_and_2 =
297 ctx.
Input().DimsCount() == 4 && ctx.Axis().ShapeFlatSize() == 2 &&
298 ((axis_value[0] == 1 && axis_value[1] == 2) || (axis_value[0] == 2 && axis_value[1] == 1));
299 if (special_case_4d_axes_1_and_2)
// NOTE(review): kInitValue is not referenced in any line visible here.
307 constexpr static T kInitValue = T(0);
// General path. The branch taken when ReduceGeneric reports failure is not shown.
309 if (!ReduceGeneric<T, ReduceFn>(ctx))
// Accessors used by the division step below.
314 auto &input = ctx.
Input();
315 auto input_dims = input.Dims();
316 auto input_num_dims = input.DimsCount();
318 auto &output = ctx.
Output();
319 auto &output_data = output.Data();
320 auto num_outputs = output.ShapeFlatSize();
322 auto &axis = ctx.Axis().Data();
323 auto num_axis_dimensions = ctx.Axis().DimsCount();
// Normalize the axis list again into a local 2-entry buffer. The remaining argument and the
// failure branch are not shown.
326 int num_resolved_axis = 0;
327 int resolved_axis[2];
329 if (!
resolveAxis(input_num_dims, axis.Get(), num_axis_dimensions, resolved_axis,
// Helper lambda: divides every output element by `divide_by` (default 1), in place.
// NOTE(review): this writes back with SetAt(), while the other blocks in this file use
// SetValueAt() on the output data — confirm both accessors exist.
337 auto fnReduceOutput = [&](
size_t divide_by = 1)
339 for (
size_t idx = 0; idx < num_outputs; ++idx)
341 auto value = output_data.ValueAt(idx);
342 value /=
static_cast<T
>(divide_by);
343 output_data.SetAt(idx, value);
// Number of input elements folded into each output: product of the resolved axis extents.
356 size_t num_elements_in_axis = 1;
357 for (
int idx = 0; idx < num_resolved_axis; ++idx)
359 size_t current =
static_cast<size_t>(input_dims[resolved_axis[idx]]);
// Overflow guard evaluated before the multiplication; the branch body is not shown.
361 if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis))
365 num_elements_in_axis *= current;
// Divide only when the element count is non-zero.
368 if (num_elements_in_axis > 0)
370 fnReduceOutput(num_elements_in_axis);
int32_t dims(int i) const
static OMRuntimeShape extendedShape(size_t new_shape_size, const OMRuntimeShape &shape)
int32_t dims(size_t i) const
const luci_interpreter::RuntimeShape output_shape
bool Reduce(OMReduceDataContext< T > &ctx, bool mean=false)
bool resolveAxis(const int num_dims, const int *axis, const int64_t num_axis, int *out_axis, int *out_num_axis)
void MeanROWH(const OMRuntimeShape &unextended_input_shape, const T *input_data, const OMRuntimeShape &unextended_output_shape, T *output_data)
bool nextIndex(const int32_t num_dims, const int32_t *dims, int32_t *current)
bool ReduceGeneric(const T *input_data, const int *input_dims, const int input_num_dims, T *output_data, const int *axis, const int64_t num_axis_dimensions, T init_value, const int output_flat_size, T reducer(const T, const T))
size_t reducedOutputOffset(const int32_t num_dims, const int32_t *dims, const int32_t *index, const int32_t num_axis, const int32_t *axis)
int offset(const int32_t *dims_data, int i0, int i1, int i2, int i3)
bool reduceSumImpl(const T *input_data, const int *input_dims, const int input_num_dims, T *output_data, const int *axis, const int num_axis, const int num_outputs)
float operator()(const float current, const float in)
static constexpr T InitValue
T operator()(const T current, const T in)
float operator()(const float current, const float in)
static constexpr T InitValue
T operator()(const T current, const T in)