66 const auto *input = kernel.
input1();
67 const auto *axis = kernel.
input2();
68 const auto *output = kernel.
output();
70 const auto *options = cur_op->builtin_options_as_ReducerOptions();
72 int num_axis =
static_cast<int>(Tensor::num_elements(axis));
73 int temp_index[kMaxNumberOfAxis];
74 int resolved_axis[kMaxNumberOfReducedAxis];
76 switch (Tensor::element_type(kernel.
input1()))
79 case DataType::FLOAT32:
82 ResolveAxis(kernels::getTensorData<int>(tiso_data.
input2_data), num_axis, &op_params);
86 bool special_case_4d_axes_1_and_2 = Tensor::num_dims(input) == 4 &&
88 ((op_params.
axis[0] == 1 && op_params.
axis[1] == 2) ||
89 (op_params.
axis[0] == 2 && op_params.
axis[1] == 1));
92 if (options->keep_dims() && special_case_4d_axes_1_and_2)
95 kernels::getTensorData<float>(tiso_data.
input1_data),
97 kernels::getTensorData<float>(tiso_data.
output_data));
102 kernels::getTensorData<float>(tiso_data.
input1_data),
103 reinterpret_cast<const int *
>(
wrap(input->shape()).data()), Tensor::num_dims(input),
104 kernels::getTensorData<float>(tiso_data.
output_data),
105 reinterpret_cast<const int *
>(
wrap(output->shape()).data()), Tensor::num_dims(output),
106 kernels::getTensorData<int>(tiso_data.
input2_data), num_axis, options->keep_dims(),
107 temp_index, resolved_axis, kernels::getTensorData<float>(tiso_data.
output_data));
113 assert(
false &&
"Unsupported type");
bool Mean(const T *input_data, const int *input_dims, const int input_num_dims, T *output_data, const int *output_dims, const int output_num_dims, const int *axis, const int num_axis_dimensions, bool, int *temp_index, int *resolved_axis, U *temp_sum)