18#include "kernels/Utils.h"
33 assert(
axis()->element_type() == DataType::S32 ||
axis()->element_type() == DataType::S64);
34 assert(
input()->shape().num_dims() >= 1);
36 const int num_dims = input_shape.
num_dims();
41 assert(
axis()->shape().num_elements() == 1);
42 int axis_value = getTensorData<int32_t>(
axis())[0];
44 axis_value = axis_value + num_dims;
45 assert(axis_value >= 0);
48 for (
int i = 0; i < num_dims; i++)
63#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
64 luci_interpreter_pal::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
65 getTensorData<axis_type>(axis()), getTensorShape(output()), \
66 getTensorData<output_type>(output()), std::greater<data_type>())
67 if (
axis()->element_type() == DataType::S32)
72 switch (
input()->element_type())
74 case DataType::FLOAT32:
81 throw std::runtime_error(
"Unsupported input type.");
85 switch (
input()->element_type())
87 case DataType::FLOAT32:
94 throw std::runtime_error(
"Unsupported input type.");
98 throw std::runtime_error(
"Unsupported output type.");
106 switch (
input()->element_type())
108 case DataType::FLOAT32:
115 throw std::runtime_error(
"Unsupported input type.");
119 switch (
input()->element_type())
121 case DataType::FLOAT32:
128 throw std::runtime_error(
"Unsupported input type.");
132 throw std::runtime_error(
"Unsupported output type.");
135#undef TF_LITE_ARG_MAX
const ArgMaxParams _params
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * axis() const
void execute() const override
const Tensor * input() const
void configure() override
ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams ¶ms)
#define TF_LITE_ARG_MAX(data_type, axis_type, output_type)
const luci_interpreter::RuntimeShape output_shape