18#include "kernels/Utils.h"
35 for (;
i < axis_count; ++
i)
53 Tensor::element_type(kernel.
output()));
66 const auto *input = kernel.
input1();
67 const auto *axis = kernel.
input2();
68 const auto *output = kernel.
output();
70 const auto *options =
cur_op->builtin_options_as_ReducerOptions();
72 int num_axis =
static_cast<int>(Tensor::num_elements(axis));
76 switch (Tensor::element_type(kernel.
input1()))
79 case DataType::FLOAT32:
95 kernels::getTensorData<float>(
tiso_data.input1_data),
97 kernels::getTensorData<float>(
tiso_data.output_data));
102 kernels::getTensorData<float>(
tiso_data.input1_data),
103 reinterpret_cast<const int *
>(
wrap(input->shape()).data()), Tensor::num_dims(input),
104 kernels::getTensorData<float>(
tiso_data.output_data),
105 reinterpret_cast<const int *
>(
wrap(output->shape()).data()), Tensor::num_dims(output),
106 kernels::getTensorData<int>(
tiso_data.input2_data),
num_axis, options->keep_dims(),
113 assert(
false &&
"Unsupported type");
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * output() const
const circle::Tensor * input2() const
const circle::Tensor * input1() const
#define LUCI_INTERPRETER_CHECK(cond)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
bool Mean(const T *input_data, const int *input_dims, const int input_num_dims, T *output_data, const int *output_dims, const int output_num_dims, const int *axis, const int num_axis_dimensions, bool, int *temp_index, int *resolved_axis, U *temp_sum)
void execute_kernel_CircleMean(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void configure_kernel_CircleMean(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
T must_cast(loco::Node *node)
VectorWrapper< T > wrap(const flatbuffers::Vector< T > *vec)
bool ResolveAxis(const int num_dims, const std::vector< int > &axes, int *out_axis, int *out_num_axis)