20#include "kernels/Utils.h"
22#include <tensorflow/lite/kernels/internal/reference/reduce.h>
32static int getAxisReductionCount(
const int32_t *axes_data,
int num_axes,
int input_num_dims)
34 int reduction_count = num_axes;
35 for (
int i = 0; i < num_axes; ++i)
37 int current = axes_data[i] >= 0 ? axes_data[i] : axes_data[i] + input_num_dims;
38 assert(current >= 0 && current < input_num_dims);
39 for (
int j = 0; j < i; j++)
41 int previous = axes_data[j] >= 0 ? axes_data[j] : axes_data[j] + input_num_dims;
43 if (current == previous)
50 return reduction_count;
53static Shape getOutputShape(
const Shape &input_shape,
const int32_t *axes_data,
int num_axes,
56 int input_num_dims = input_shape.num_dims();
57 if (input_num_dims == 0)
65 for (
int idx = 0; idx < input_num_dims; ++idx)
68 for (
int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
70 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
89 int num_reduce_axes = getAxisReductionCount(axes_data, num_axes, input_num_dims);
91 int num_skip_axes = 0;
92 for (
int idx = 0; idx < input_num_dims; ++idx)
95 for (
int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
97 if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
106 output_shape.dim(idx - num_skip_axes) = input_shape.dim(idx);
125 int input_num_dims = input_shape.
num_dims();
127 const auto *axes_data = getTensorData<int32_t>(
axes());
140 temp_index->resize(
Shape(input_num_dims));
141 resolved_axes->resize(
Shape(num_axes));
146 switch (
input()->element_type())
148 case DataType::FLOAT32:
152 throw std::runtime_error(
"luci-intp Sum Unsupported type.");
156void Sum::evalFloat()
const
158 const auto *axes_data = getTensorData<int32_t>(
axes());
164 int num_resolved_axis = 0;
166 tflite::reference_ops::ResolveAxis(
input()->shape().num_dims(), axes_data, num_axes,
167 getTensorData<int>(resolved_axes), &num_resolved_axis));
169 float init_value = 0.0;
170 tflite::reference_ops::ReduceGeneric<float>(
174 getTensorData<int>(temp_index), getTensorData<int>(resolved_axes), init_value,
175 [](
const float current,
const float in) ->
float { return current + in; });
const std::vector< Tensor * > & getOutputTensors() const
const ReducerParams _params
int32_t num_elements() const
void resize(const Shape &new_shape)
const Shape & shape() const
Sum(const Tensor *input, const Tensor *axes, Tensor *output, Tensor *temp_index, Tensor *resolved_axes, const ReducerParams ¶ms)
void configure() override
const Tensor * axes() const
const Tensor * input() const
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)