45 const auto reductor = [](T result, T x) {
return result + x; };
55 std::vector<bool> reduction_dims_mask(input_shape.rank(),
false);
56 for (
const int dim : reduction_dims)
58 reduction_dims_mask[dim] =
true;
64 int out_index_dim = 0;
65 for (
int dim = 0; dim < input_shape.rank(); ++dim)
69 out_index.
at(out_index_dim++) = reduction_dims_mask[dim] ? 0 : in_index.at(dim);
73 if (!reduction_dims_mask[dim])
75 out_index.
at(out_index_dim++) = in_index.at(dim);
79 res_accessor.
at(out_index) = reductor(res_accessor.
at(out_index), input.at(in_index));
82 const std::int32_t reduction_factor = input_shape.numElements() /
output_shape.numElements();
86 res_accessor.
at(index) /= reduction_factor;