ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ComparisonCommon.h
Go to the documentation of this file.
1
/*
2
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3
*
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
* you may not use this file except in compliance with the License.
6
* You may obtain a copy of the License at
7
*
8
* http://www.apache.org/licenses/LICENSE-2.0
9
*
10
* Unless required by applicable law or agreed to in writing, software
11
* distributed under the License is distributed on an "AS IS" BASIS,
12
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
* See the License for the specific language governing permissions and
14
* limitations under the License.
15
*/
16
17
#ifndef LUCI_INTERPRETER_KERNELS_COMPARISONCOMMON_H
18
#define LUCI_INTERPRETER_KERNELS_COMPARISONCOMMON_H
19
20
#include "
Builders.h
"
21
22
#include "kernels/Utils.h"
23
#include "PALComparisons.h"
24
25
namespace
luci_interpreter
26
{
27
namespace
kernels
28
{
29
30
template
<
typename
T>
31
void
evalComparisonGeneric
(
const
circle::Tensor *x,
const
circle::Tensor *y,
32
const
circle::Tensor *output,
BaseRuntimeGraph
*runtime_graph,
33
bool
F
(T, T))
34
{
35
auto
x_data
= kernels::getTensorData<T>(runtime_graph->
getDataByTensor
(x));
36
if
(
x_data
==
nullptr
)
37
x_data
= kernels::getTensorData<T>(runtime_graph->
getConstDataByTensor
(x));
38
39
assert(
x_data
!=
nullptr
);
40
41
auto
y_data
= kernels::getTensorData<T>(runtime_graph->
getDataByTensor
(y));
42
if
(
y_data
==
nullptr
)
43
y_data
= kernels::getTensorData<T>(runtime_graph->
getConstDataByTensor
(y));
44
45
assert(
y_data
!=
nullptr
);
46
47
auto
output_data = kernels::getTensorData<bool>(runtime_graph->
getDataByTensor
(output));
48
49
luci_interpreter_pal::ComparisonParams
op_params
;
50
op_params
.
is_broadcast
= Tensor::num_elements(x) != Tensor::num_elements(y);
51
52
if
(
op_params
.is_broadcast)
53
{
54
luci_interpreter_pal::BroadcastComparison4DSlowNoScaling<T>(
55
op_params
,
kernels::getTensorShape
(x),
x_data
,
kernels::getTensorShape
(y),
y_data
,
56
kernels::getTensorShape
(output), output_data,
F
);
57
}
58
else
59
{
60
const
int64_t
flat_size
=
kernels::getTensorShape
(x).flatSize();
61
luci_interpreter_pal::ComparisonNoScaling<T>(
flat_size
,
x_data
,
y_data
, output_data,
F
);
62
}
63
}
64
65
}
// namespace kernels
66
}
// namespace luci_interpreter
67
68
#endif
// LUCI_INTERPRETER_KERNELS_COMPARISONCOMMON_H
luci_interpreter::RuntimeGraph
Definition
RuntimeGraph.h:33
luci_interpreter::RuntimeGraph::getConstDataByTensor
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
Definition
RuntimeGraph.cpp:398
luci_interpreter::RuntimeGraph::getDataByTensor
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
Definition
RuntimeGraph.cpp:355
luci_interpreter::kernels::evalComparisonGeneric
void evalComparisonGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output, BaseRuntimeGraph *runtime_graph, bool F(T, T))
Definition
ComparisonCommon.h:31
luci_interpreter::kernels::getTensorShape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition
Utils.h:194
luci_interpreter
Definition
BuddyMemoryManager.h:22
luci::must_cast
T must_cast(loco::Node *node)
Definition
CircleNodeDecl.h:95
Builders.h
luci_interpreter_pal::ComparisonParams
Definition
Params.h:147
luci_interpreter_pal::ComparisonParams::is_broadcast
bool is_broadcast
Definition
Params.h:157
onert-micro
luci-interpreter
src
kernels
ComparisonCommon.h
Generated by
1.9.8