18#include "kernels/Utils.h"
31void sumGeneric(kernels::TISOData *tiso_data,
const circle::Tensor *input,
32 const circle::Tensor *axis,
const circle::Tensor *output,
bool keep_dims)
34 const int input_rank = Tensor::num_dims(input);
35 const int num_axis = Tensor::num_elements(axis);
37 auto const input_dims =
wrap(
input->shape());
40 luci_interpreter_pal::ReduceGeneric<T>(
41 kernels::getTensorData<T>(tiso_data->input1_data),
42 reinterpret_cast<const int *
>(input_dims.data()), input_rank,
43 kernels::getTensorData<T>(tiso_data->output_data),
44 kernels::getTensorData<int>(tiso_data->input2_data), num_axis,
46 [](
const float current,
const float in) ->
float { return in + current; });
56 Tensor::element_type(kernel.
output()));
59 const int32_t axis_value =
69 const auto *input = kernel.
input1();
70 const auto *axis = kernel.
input2();
71 const auto *output = kernel.
output();
73 const auto *options = cur_op->builtin_options_as_ReducerOptions();
75 switch (Tensor::element_type(kernel.
input1()))
78 case DataType::FLOAT32:
79 sumGeneric<float>(&tiso_data, input, axis, output, options->keep_dims());
83 assert(
false &&
"Unsupported type");
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * output() const
const circle::Tensor * input2() const
const circle::Tensor * input1() const
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void configure_kernel_CircleSum(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void execute_kernel_CircleSum(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
VectorWrapper< T > wrap(const flatbuffers::Vector< T > *vec)