// NOTE(review): this chunk looks like extraction residue of ArgMinMax rather
// than compilable source. The leading integers (39, 42, ...) appear to be the
// original file's line numbers fused into the text, and statements are split
// across lines. The braces and the declarations of outer_size, inner_size,
// dims_count and min_max_index are not visible here. Restore this region from
// the upstream file rather than hand-editing the text below.
//
// axis_size: number of elements scanned along `axis` by the reduction loop below.
39 const int axis_size = input1_shape.
Dims(axis);
// outer_size accumulates the product of the dims before `axis`. Its
// declaration/initial value is outside this view (presumably 1).
42 for (
int i = 0; i < axis; ++i)
45 outer_size *= input1_shape.
Dims(i);
// inner_size accumulates the product of the dims after `axis`. dims_count and
// inner_size's initial value are outside this view (presumably rank and 1).
50 for (
int i = axis + 1; i < dims_count; ++i)
53 inner_size *= input1_shape.
Dims(i);
// The indexing below treats input1_data as a row-major
// [outer_size, axis_size, inner_size] block. For every (outer, inner) pair it
// scans the axis_size elements along `axis`.
55 for (
int outer = 0; outer < outer_size; ++outer)
57 for (
int inner = 0; inner < inner_size; ++inner)
// Seed the running best with the element at position 0 along the axis.
// NOTE(review): min_max_index must start at 0 for every (outer, inner) pair.
// Its declaration is not visible here, so verify it sits inside this loop.
59 auto min_max_value = input1_data[outer * axis_size * inner_size + inner];
// Replace the running best whenever cmp(candidate, best) is true. With a
// strict comparator (e.g. std::less / std::greater), ties keep the earliest
// index. This assumes callers pass a strict predicate -- TODO confirm.
61 for (
int i = 1; i < axis_size; ++i)
63 const auto &curr_value = input1_data[(outer * axis_size + i) * inner_size + inner];
64 if (cmp(curr_value, min_max_value))
66 min_max_value = curr_value;
// Record the winning position along the axis, converted to the output index type T2.
67 min_max_index =
static_cast<T2
>(i);
// One output element per (outer, inner) pair: the axis dimension is dropped from the output.
70 output_data[outer * inner_size + inner] = min_max_index;
// ArgMinMax: for each position of input1 with `axis` removed, writes the index
// along `axis` of the element selected by `cmp`. The name suggests arg-min or
// arg-max, depending on which comparator the caller supplies.
//   input1_shape / input1_data: shape of the input and its flat element data.
//   output_shape: not referenced in the visible lines; presumably validated
//                 elsewhere -- confirm.
//   output_data:  receives one T2 index per (outer, inner) pair.
//   axis:         dimension to reduce over. The visible code does not
//                 normalize negative values, so it assumes 0 <= axis < rank --
//                 TODO confirm callers handle negative axes.
//   cmp:          binary predicate; cmp(candidate, current_best) == true
//                 makes candidate the new best.
void ArgMinMax(const Shape &input1_shape, const T1 *input1_data, const Shape &output_shape, T2 *output_data, int32_t axis, const Cmp &cmp)