19#include "kernels/Utils.h"
32 Tensor::num_elements(kernel.
input2()) == 1);
43 const circle::Tensor *input = kernel.
input1();
44 const circle::Tensor *output = kernel.
output();
47 const auto input_data =
tiso_data.input1_data;
51 switch (Tensor::element_type(input))
54 case DataType::FLOAT32:
56 luci_interpreter_pal::ArgMinMax(
58 kernels::getTensorData<float>(input_data), kernels::getTensorData<int32_t>(
axis_data),
60 kernels::getTensorData<int32_t>(output_data), std::greater<float>());
65 assert(
false &&
"Unsupported ArgMax input type");
const circle::Tensor * output() const
const circle::Tensor * input2() const
const circle::Tensor * input1() const
#define LUCI_INTERPRETER_CHECK(cond)
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
void configure_kernel_CircleArgMax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void execute_kernel_CircleArgMax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
T must_cast(loco::Node *node)