20#include "kernels/Utils.h"
22#include <tensorflow/lite/kernels/internal/reference/reduce.h>
33static int getAxisReductionCount(
const int32_t *axes_data,
int num_axes,
int input_num_dims)
35 int reduction_count = num_axes;
36 for (
int i = 0; i < num_axes; ++i)
38 int current = axes_data[i] >= 0 ? axes_data[i] : axes_data[i] + input_num_dims;
39 assert(current >= 0 && current < input_num_dims);
40 for (
int j = 0; j < i; j++)
42 int previous = axes_data[j] >= 0 ? axes_data[j] : axes_data[j] + input_num_dims;
44 if (current == previous)
51 return reduction_count;
54static Shape getOutputShape(
const Shape &input_shape,
const int32_t *axes_data,
int num_axes,
57 int input_num_dims = input_shape.num_dims();
58 if (input_num_dims == 0)
66 for (
int idx = 0; idx < input_num_dims; ++idx)
69 for (
int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
71 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
90 int num_reduce_axes = getAxisReductionCount(axes_data, num_axes, input_num_dims);
92 int num_skip_axes = 0;
93 for (
int idx = 0; idx < input_num_dims; ++idx)
96 for (
int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
98 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
107 output_shape.dim(idx - num_skip_axes) = input_shape.dim(idx);
126 int input_num_dims = input_shape.
num_dims();
128 const auto *axes_data = getTensorData<int32_t>(
axes());
141 temp_index->resize(
Shape(input_num_dims));
142 resolved_axes->resize(
Shape(num_axes));
147 switch (
input()->element_type())
149 case DataType::FLOAT32:
157 throw std::runtime_error(
"luci-intp ReduceMax Unsupported type.");
161void ReduceMax::evalFloat()
const
163 const auto *axes_data = getTensorData<int32_t>(
axes());
169 int num_resolved_axis = 0;
171 tflite::reference_ops::ResolveAxis(
input()->shape().num_dims(), axes_data, num_axes,
172 getTensorData<int>(resolved_axes), &num_resolved_axis));
174 float init_value = std::numeric_limits<float>::lowest();
175 tflite::reference_ops::ReduceGeneric<float>(
179 getTensorData<int>(temp_index), getTensorData<int>(resolved_axes), init_value,
180 [](
const float current,
const float in) ->
float { return (in > current) ? in : current; });
183void ReduceMax::evalBool()
const
185 const auto *axes_data = getTensorData<int32_t>(
axes());
191 int num_resolved_axis = 0;
193 tflite::reference_ops::ResolveAxis(
input()->shape().num_dims(), axes_data, num_axes,
194 getTensorData<int>(resolved_axes), &num_resolved_axis));
196 bool init_value = std::numeric_limits<bool>::lowest();
197 tflite::reference_ops::ReduceGeneric<bool>(
201 getTensorData<int>(temp_index), getTensorData<int>(resolved_axes), init_value,
202 [](
const bool current,
const bool in) ->
bool { return (in > current) ? in : current; });
const std::vector< Tensor * > & getOutputTensors() const
const ReducerParams _params
int32_t num_elements() const
void resize(const Shape &new_shape)
const Shape & shape() const
void execute() const override
const Tensor * input() const
ReduceMax(const Tensor *input, const Tensor *axes, Tensor *output, Tensor *temp_index, Tensor *resolved_axes, const ReducerParams ¶ms)
void configure() override
const Tensor * axes() const
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)