ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ArgMax.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "kernels/ArgMax.h"
18#include "kernels/Utils.h"
19#include "PALArgMax.h"
20
21namespace luci_interpreter
22{
23namespace kernels
24{
25
26ArgMax::ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams &params)
27 : KernelWithParams<ArgMaxParams>({input, axis}, {output}, params)
28{
29}
30
32{
33 assert(axis()->element_type() == DataType::S32 || axis()->element_type() == DataType::S64);
34 assert(input()->shape().num_dims() >= 1);
35 const Shape &input_shape = input()->shape();
36 const int num_dims = input_shape.num_dims();
37 Shape output_shape(num_dims - 1);
38
39 // If axis value is negative, then update by adding input_shape's num_dims.
40 // If updated value also negative, then assert.
41 assert(axis()->shape().num_elements() == 1);
42 int axis_value = getTensorData<int32_t>(axis())[0];
43 if (axis_value < 0)
44 axis_value = axis_value + num_dims;
45 assert(axis_value >= 0);
46
47 int j = 0;
48 for (int i = 0; i < num_dims; i++)
49 {
50 if (i == axis_value)
51 continue;
52 output_shape.dim(j++) = input_shape.dim(i);
53 }
54
55 assert(output()->element_type() == _params.output_type);
56
58}
59
60void ArgMax::execute() const
61{
62
63#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
64 luci_interpreter_pal::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
65 getTensorData<axis_type>(axis()), getTensorShape(output()), \
66 getTensorData<output_type>(output()), std::greater<data_type>())
67 if (axis()->element_type() == DataType::S32)
68 {
69 switch (_params.output_type)
70 {
71 case DataType::S32:
72 switch (input()->element_type())
73 {
74 case DataType::FLOAT32:
75 TF_LITE_ARG_MAX(float, int32_t, int32_t);
76 break;
77 case DataType::U8:
78 TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
79 break;
80 default:
81 throw std::runtime_error("Unsupported input type.");
82 }
83 break;
84 case DataType::S64:
85 switch (input()->element_type())
86 {
87 case DataType::FLOAT32:
88 TF_LITE_ARG_MAX(float, int32_t, int64_t);
89 break;
90 case DataType::U8:
91 TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
92 break;
93 default:
94 throw std::runtime_error("Unsupported input type.");
95 }
96 break;
97 default:
98 throw std::runtime_error("Unsupported output type.");
99 }
100 }
101 else
102 {
103 switch (_params.output_type)
104 {
105 case DataType::S32:
106 switch (input()->element_type())
107 {
108 case DataType::FLOAT32:
109 TF_LITE_ARG_MAX(float, int64_t, int32_t);
110 break;
111 case DataType::U8:
112 TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
113 break;
114 default:
115 throw std::runtime_error("Unsupported input type.");
116 }
117 break;
118 case DataType::S64:
119 switch (input()->element_type())
120 {
121 case DataType::FLOAT32:
122 TF_LITE_ARG_MAX(float, int64_t, int64_t);
123 break;
124 case DataType::U8:
125 TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
126 break;
127 default:
128 throw std::runtime_error("Unsupported input type.");
129 }
130 break;
131 default:
132 throw std::runtime_error("Unsupported output type.");
133 }
134 }
135#undef TF_LITE_ARG_MAX
136}
137
138} // namespace kernels
139} // namespace luci_interpreter
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const Tensor * axis() const
Definition ArgMax.h:34
void execute() const override
Definition ArgMax.cpp:60
const Tensor * input() const
Definition ArgMax.h:33
ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams &params)
Definition ArgMax.cpp:26
#define TF_LITE_ARG_MAX(data_type, axis_type, output_type)
const luci_interpreter::RuntimeShape output_shape