18#ifndef ONERT_MICRO_PAL_REDUCE_COMMON_H
19#define ONERT_MICRO_PAL_REDUCE_COMMON_H
27#include <unordered_map>
33using core::type_traits::IsQuantized;
69 total = std::max(total, value);
// Reducer: class template parameterised on the element type T and on a reduction functor
// template ReduceFn.
// NOTE(review): only a subset of the class body is visible in this listing (the class header,
// constructor signature and several members are missing), so the comments below describe only
// what the visible lines show.
75template <
typename T,
template <
typename>
class ReduceFn>
// Type used for accumulation: float when IsQuantized<T> holds, otherwise T itself.
78 using ValueType = std::conditional_t<IsQuantized<T>, float, T>;
// Reduction functor instantiated on the accumulation type (not on T).
86 ReduceFn<ValueType> _reducer;
// Current multi-dimensional input index; indexed by dimension in ReducedOutputOffset/NextIndex.
88 std::unordered_map<size_t, size_t> _curr_index = {};
// Axes recorded by ResolveAxis(), stored as {running counter -> axis value}.
89 std::unordered_map<size_t, uint32_t> _resolved_axes = {};
// Partial results, indexed by output offset in ReduceImpl().
90 std::unordered_map<size_t, ValueType> _accumulator = {};
// Constructor initialiser-list entries: output view taken from the context, and the initial
// accumulator value supplied by the caller.
95 , _output(ctx.Output())
97 , _init_value(init_value)
// SpecialCaseMeanImpl() is consulted first; the generic path is ReduceImpl(true), i.e. with the
// `mean` flag set. NOTE(review): the statement guarded by this `if` is not visible here.
103 if (SpecialCaseMeanImpl())
106 return ReduceImpl(
true);
// Generic reduction; `mean` defaults to false.
115 bool ReduceImpl(
bool mean =
false);
// Dedicated path for a specific mean configuration (see its definition for the visible checks).
116 bool SpecialCaseMeanImpl();
// Product of the input dimension lengths over the resolved axes, returned as T.
119 T ResolvedAxisLength();
// Offset computed from _curr_index, taking an optional list of axes.
121 size_t ReducedOutputOffset(
int num_axes,
const uint32_t *axes);
// Rebuilds _resolved_axes from the values stored in the _axes tensor: the map is cleared and
// entries are then inserted at key `num_resolved_axes++` with the axis value read from _axes.
// NOTE(review): the branch bodies and return statements of this function are not visible in this
// listing, so the exact success/failure conditions cannot be documented from here.
127template <
typename T,
template <
typename>
class ReduceFn>
128bool Reducer<T, ReduceFn>::ResolveAxis()
130 size_t num_resolved_axes = 0;
131 _resolved_axes.clear();
// A scalar input is tested for separately before the axes are walked.
133 if (_input.IsScalar())
// Walk every element of the axes tensor.
136 for (
size_t i = 0; i < _axes.ElementsCount(); ++i)
// NOTE(review): the axis is read into a signed int; no negative-axis normalisation is visible
// in this listing - confirm whether negative axes are expected/handled.
138 int current = _axes.Data().At(i);
// NOTE(review): unordered_map::count() looks up *keys*, but entries are stored below as
// {counter -> axis value}. If this test is meant to detect a repeated axis value it compares
// against the wrong field (e.g. axes {1, 1} would not match, axes {2, 0} would) - confirm intent.
140 if (_resolved_axes.count(current) > 0)
143 if (_resolved_axes.size() > 1)
// Record the axis under the next free counter value.
146 _resolved_axes[num_resolved_axes++] = current;
// Dedicated mean path. The visible checks require a 4-D input, exactly two axes, and the axis
// set to contain both 1 and 2; the loops then average the input over dimensions 1 and 2 for
// every (batch, depth) pair.
// NOTE(review): the bodies of the `if` checks and the return statements are not visible in this
// listing.
152template <
typename T,
template <
typename>
class ReduceFn>
153bool Reducer<T, ReduceFn>::SpecialCaseMeanImpl()
159 const uint32_t *axes_data = _axes.Data().Get();
// NOTE(review): axes_data[0] and axes_data[1] are read here, before the ElementsCount() != 2
// check below; with fewer than two axes this reads past the end of the axes buffer - confirm
// whether an earlier guard exists or move the read after the check.
160 std::set<uint32_t> axes_values = { axes_data[0], axes_data[1] };
162 if (_input.DimsCount() != 4)
165 if (_axes.ElementsCount() != 2)
// Both axis 1 and axis 2 must be present in the set.
168 if (axes_values.count(1) != 1 || axes_values.count(2) != 1)
// Shapes are passed through extendedShape(4, ...) before being indexed by dimension.
171 auto input_shape = OMRuntimeShape::extendedShape(4, _input.Shape());
172 auto output_shape = OMRuntimeShape::extendedShape(4, _output.Shape());
// Dimensions 1 and 2 of the input are the ones being averaged over.
177 const int input_height = input_shape.dims(1);
178 const int input_width = input_shape.dims(2);
// Outer loops: one output element per (out_b, out_d). output_batch/output_depth are defined on
// lines not visible here.
180 for (
int out_b = 0; out_b < output_batch; ++out_b)
182 for (
int out_d = 0; out_d < output_depth; ++out_d)
// Inner loops: accumulate every (in_h, in_w) input element for this (out_b, out_d).
186 for (
int in_h = 0; in_h < input_height; ++in_h)
188 for (
int in_w = 0; in_w < input_width; ++in_w)
190 size_t idx =
offset(input_shape.dimsData(), out_b, in_h, in_w, out_d);
// Accumulation is done in float regardless of T.
191 value +=
static_cast<float>(_input.Data().At(idx));
// Divide the sum by the number of averaged elements (int product converted to float).
195 float result = value / (input_width * input_height);
// NOTE(review): the `idx` used here is presumably an output index declared on a line that is
// not visible (the input `idx` above is scoped to the inner loop) - confirm.
197 _output.Data().SetAt(idx, result);
// Generic reduction: resets the accumulator, seeds entries with _init_value, folds input values
// into the accumulator slot selected by an output offset via _reducer, and finally writes each
// accumulated value to the output.
// NOTE(review): loop headers, conditions and return statements are not visible in this listing.
204template <
typename T,
template <
typename>
class ReduceFn>
205bool Reducer<T, ReduceFn>::ReduceImpl(
bool mean)
// Drop any partial results held in the accumulator.
207 _accumulator.clear();
210 auto *axes_data = _axes.Data().Get();
// Inputs with a zero-sized dimension are tested for separately.
215 if (_input.HasZeroSizeDims())
// Seed an accumulator slot with the initial value.
222 _accumulator[i] = _init_value;
// Fold one input element into the accumulator slot for its output position.
235 _reducer(_accumulator[output_offset],
input_data.ValueAt(input_offset));
// at() is used here, so the slot must already exist.
241 auto value = _accumulator.at(i);
// Division by the number of reduced elements - presumably only when `mean` is set; the guarding
// condition is not visible here.
245 value /= ResolvedAxisLength();
248 _output.Data().SetValueAt(i, value);
// Multiplies together the input dimension lengths of every entry in _resolved_axes and returns
// the product converted to T.
// NOTE(review): the declaration/initial value of axis_length and the bodies of the `if`
// statements are not visible in this listing.
// NOTE(review): the product is narrowed to T on return. For a small integer T the element count
// may not fit, while the caller divides a ValueType (float for quantized T) by it - confirm that
// this narrowing is intended.
254template <
typename T,
template <
typename>
class ReduceFn>
255T Reducer<T, ReduceFn>::ResolvedAxisLength()
// Largest size_t value, used by the overflow guard below.
258 constexpr static auto kMax = std::numeric_limits<size_t>::max();
// _resolved_axes is looked up with at() for the keys 0 .. size()-1.
260 for (
auto i = 0u; i < _resolved_axes.size(); ++i)
262 auto &axis = _resolved_axes.at(i);
263 auto current =
static_cast<size_t>(_input.Dims()[axis]);
// Overflow guard: tests whether axis_length * current would exceed kMax before multiplying.
269 if (current > (kMax / axis_length))
272 axis_length *= current;
275 return static_cast<T
>(axis_length);
// Computes a flat offset from _curr_index by walking the input dimensions in order and applying
// offset = offset * DimLength(dim) + _curr_index[dim].
// A dimension is flagged as `skip_axis` when it appears in axes_data[0 .. num_axes).
// NOTE(review): the declaration of `offset`, the condition guarding the two offset updates
// (presumably `!skip_axis`) and the return statement are not visible in this listing.
284template <
typename T,
template <
typename>
class ReduceFn>
285size_t Reducer<T, ReduceFn>::ReducedOutputOffset(
int num_axes,
const uint32_t *axes_data)
289 for (
auto dim_idx = 0u; dim_idx < _input.DimsCount(); ++dim_idx)
291 bool skip_axis =
false;
// A null axes pointer means no dimension is flagged.
293 if (axes_data !=
nullptr)
// Linear search of the axes list for the current dimension.
295 skip_axis = std::any_of(axes_data, axes_data + num_axes, [&dim_idx](
auto axis)
297 return axis == dim_idx;
// Offset accumulation: scale by this dimension's length, then add the current index.
303 offset *= _input.DimLength(dim_idx);
// operator[] value-initialises (0) an entry that does not exist yet.
304 offset += _curr_index[dim_idx];
// Advances _curr_index by one position, starting from the last dimension: the index of a
// dimension is incremented, and when it reaches that dimension's length it is reset to 0.
// NOTE(review): the return statements and branch bodies are not visible in this listing;
// presumably the function returns after a successful increment and otherwise carries into the
// next-outer dimension - confirm.
314template <
typename T,
template <
typename>
class ReduceFn>
315bool Reducer<T, ReduceFn>::NextIndex()
// Zero-dimensional input is tested for separately.
317 if (_input.DimsCount() == 0)
// Signed loop variable so the `idx >= 0` condition can terminate after dimension 0.
322 for (
int idx = _input.DimsCount() - 1; idx >= 0; --idx)
324 auto current_val = _curr_index[idx] + 1;
// The incremented value is stored only while it differs from this dimension's length.
326 if (_input.DimLength(idx) != current_val)
328 _curr_index[idx] = current_val;
// This dimension wrapped around: reset it to zero.
332 _curr_index[idx] = 0;
int32_t dims(int i) const
Reducer(core::OMReduceDataContext< T > &ctx, T init_value)
const luci_interpreter::RuntimeShape output_shape
bool NextIndex(const int num_dims, const int *dims, int *current)
size_t ReducedOutputOffset(const int num_dims, const int *dims, const int *index, const int num_axis, const int *axis)
bool ResolveAxis(const int num_dims, const std::vector< int > &axes, int *out_axis, int *out_num_axis)
int offset(const int32_t *dims_data, int i0, int i1, int i2, int i3)
void operator()(T &total, const T value)
void operator()(T &total, const T value)
void operator()(T &total, const T value)