ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::ArgMax Class Reference

#include <ArgMax.h>

Collaboration diagram for luci_interpreter::kernels::ArgMax:

Public Member Functions

 ArgMax (const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams &params)
 
const Tensor * input () const
 
const Tensor * axis () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< ArgMaxParams >
const ArgMaxParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< ArgMaxParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const ArgMaxParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< ArgMaxParams >
const ArgMaxParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file ArgMax.h.

Constructor & Destructor Documentation

◆ ArgMax()

luci_interpreter::kernels::ArgMax::ArgMax ( const Tensor * input,
const Tensor * axis,
Tensor * output,
const ArgMaxParams & params 
)

Definition at line 26 of file ArgMax.cpp.

27 : KernelWithParams<ArgMaxParams>({input, axis}, {output}, params)
28{
29}
const ArgMaxParams & params() const
Definition Kernel.h:67
const Tensor * axis() const
Definition ArgMax.h:34
const Tensor * input() const
Definition ArgMax.h:33

References axis(), and input().

Member Function Documentation

◆ axis()

const Tensor * luci_interpreter::kernels::ArgMax::axis ( ) const
inline

Definition at line 34 of file ArgMax.h.

34{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by ArgMax(), configure(), and execute().

◆ configure()

void luci_interpreter::kernels::ArgMax::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 31 of file ArgMax.cpp.

32{
33 assert(axis()->element_type() == DataType::S32 || axis()->element_type() == DataType::S64);
34 assert(input()->shape().num_dims() >= 1);
35 const Shape &input_shape = input()->shape();
36 const int num_dims = input_shape.num_dims();
37 Shape output_shape(num_dims - 1);
38
39 // If axis value is negative, then update by adding input_shape's num_dims.
40 // If updated value also negative, then assert.
41 assert(axis()->shape().num_elements() == 1);
42 int axis_value = getTensorData<int32_t>(axis())[0];
43 if (axis_value < 0)
44 axis_value = axis_value + num_dims;
45 assert(axis_value >= 0);
46
47 int j = 0;
48 for (int i = 0; i < num_dims; i++)
49 {
50 if (i == axis_value)
51 continue;
52 output_shape.dim(j++) = input_shape.dim(i);
53 }
54
55 assert(output()->element_type() == _params.output_type);
56
57 output()->resize(output_shape);
58}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const luci_interpreter::RuntimeShape output_shape
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59
Definition Shape.h:28

References luci_interpreter::KernelWithParams< ArgMaxParams >::_params, axis(), luci_interpreter::Shape::dim(), input(), luci_interpreter::Shape::num_dims(), output(), output_shape, luci_interpreter::ArgMaxParams::output_type, luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::ArgMax::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 60 of file ArgMax.cpp.

61{
62
63#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
64 luci_interpreter_pal::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
65 getTensorData<axis_type>(axis()), getTensorShape(output()), \
66 getTensorData<output_type>(output()), std::greater<data_type>())
67 if (axis()->element_type() == DataType::S32)
68 {
69 switch (_params.output_type)
70 {
71 case DataType::S32:
72 switch (input()->element_type())
73 {
74 case DataType::FLOAT32:
75 TF_LITE_ARG_MAX(float, int32_t, int32_t);
76 break;
77 case DataType::U8:
78 TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
79 break;
80 default:
81 throw std::runtime_error("Unsupported input type.");
82 }
83 break;
84 case DataType::S64:
85 switch (input()->element_type())
86 {
87 case DataType::FLOAT32:
88 TF_LITE_ARG_MAX(float, int32_t, int64_t);
89 break;
90 case DataType::U8:
91 TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
92 break;
93 default:
94 throw std::runtime_error("Unsupported input type.");
95 }
96 break;
97 default:
98 throw std::runtime_error("Unsupported output type.");
99 }
100 }
101 else
102 {
103 switch (_params.output_type)
104 {
105 case DataType::S32:
106 switch (input()->element_type())
107 {
108 case DataType::FLOAT32:
109 TF_LITE_ARG_MAX(float, int64_t, int32_t);
110 break;
111 case DataType::U8:
112 TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
113 break;
114 default:
115 throw std::runtime_error("Unsupported input type.");
116 }
117 break;
118 case DataType::S64:
119 switch (input()->element_type())
120 {
121 case DataType::FLOAT32:
122 TF_LITE_ARG_MAX(float, int64_t, int64_t);
123 break;
124 case DataType::U8:
125 TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
126 break;
127 default:
128 throw std::runtime_error("Unsupported input type.");
129 }
130 break;
131 default:
132 throw std::runtime_error("Unsupported output type.");
133 }
134 }
135#undef TF_LITE_ARG_MAX
136}
#define TF_LITE_ARG_MAX(data_type, axis_type, output_type)

References luci_interpreter::KernelWithParams< ArgMaxParams >::_params, axis(), input(), luci_interpreter::ArgMaxParams::output_type, and TF_LITE_ARG_MAX.

◆ input()

const Tensor * luci_interpreter::kernels::ArgMax::input ( ) const
inline

Definition at line 33 of file ArgMax.h.

33{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by ArgMax(), configure(), and execute().

◆ output()

Tensor * luci_interpreter::kernels::ArgMax::output ( ) const
inline

Definition at line 35 of file ArgMax.h.

35{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().


The documentation for this class was generated from the following files: