ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ComparisonCommon.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LUCI_INTERPRETER_KERNELS_COMPARISONCOMMON_H
18#define LUCI_INTERPRETER_KERNELS_COMPARISONCOMMON_H
19
20#include "Builders.h"
21
22#include "kernels/Utils.h"
23#include "PALComparisons.h"
24
25namespace luci_interpreter
26{
27namespace kernels
28{
29
30template <typename T>
31void evalComparisonGeneric(const circle::Tensor *x, const circle::Tensor *y,
32 const circle::Tensor *output, BaseRuntimeGraph *runtime_graph,
33 bool F(T, T))
34{
35 auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x));
36 if (x_data == nullptr)
37 x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x));
38
39 assert(x_data != nullptr);
40
41 auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y));
42 if (y_data == nullptr)
43 y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y));
44
45 assert(y_data != nullptr);
46
47 auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));
48
50 op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);
51
52 if (op_params.is_broadcast)
53 {
54 luci_interpreter_pal::BroadcastComparison4DSlowNoScaling<T>(
55 op_params, kernels::getTensorShape(x), x_data, kernels::getTensorShape(y), y_data,
56 kernels::getTensorShape(output), output_data, F);
57 }
58 else
59 {
60 const int64_t flat_size = kernels::getTensorShape(x).flatSize();
61 luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data, F);
62 }
63}
64
65} // namespace kernels
66} // namespace luci_interpreter
67
68#endif // LUCI_INTERPRETER_KERNELS_COMPARISONCOMMON_H
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
void evalComparisonGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output, BaseRuntimeGraph *runtime_graph, bool F(T, T))
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194