ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Cast.cpp
Go to the documentation of this file.
1
/*
2
* Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4
*
5
* Licensed under the Apache License, Version 2.0 (the "License");
6
* you may not use this file except in compliance with the License.
7
* You may obtain a copy of the License at
8
*
9
* http://www.apache.org/licenses/LICENSE-2.0
10
*
11
* Unless required by applicable law or agreed to in writing, software
12
* distributed under the License is distributed on an "AS IS" BASIS,
13
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
* See the License for the specific language governing permissions and
15
* limitations under the License.
16
*/
17
18
#include "
Builders.h
"
19
#include "kernels/Utils.h"
20
#include "
SISOKernel.h
"
21
22
namespace
luci_interpreter
23
{
24
namespace
{

/**
 * Converts `num_elements` values from `in` into `out` element by element.
 *
 * Each value is converted with `static_cast<ToT>`, and slot i of `out` receives
 * the converted slot i of `in`. Nothing beyond out[num_elements - 1] is written.
 *
 * NOTE(review): when FromT is floating point and ToT is integral, a value whose
 * truncation does not fit in ToT is undefined behaviour under static_cast; no
 * clamping is done here, so callers must guarantee the range.
 *
 * @param in            source buffer, at least `num_elements` readable items
 * @param out           destination buffer, at least `num_elements` writable items
 * @param num_elements  number of items to convert
 */
template <typename FromT, typename ToT> void copyCast(const FromT *in, ToT *out, int num_elements)
{
  for (int idx = 0; idx < num_elements; ++idx)
  {
    out[idx] = static_cast<ToT>(in[idx]);
  }
}

} // namespace
33
34
/**
 * Validates the tensors of a Cast operator before execution.
 *
 * The kernel converts data element by element. This check requires the input and
 * output tensors to contain the same number of elements and to have the same
 * number of dimensions.
 *
 * @param cur_op         the circle Cast operator being configured
 * @param runtime_graph  graph that owns the operator's tensors
 */
void configure_kernel_CircleCast(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
  const kernels::SISOKernel siso(cur_op, runtime_graph);
  const auto *in_tensor = siso.input();
  const auto *out_tensor = siso.output();

  // Element counts must match because execution walks both buffers in lockstep.
  LUCI_INTERPRETER_CHECK(Tensor::num_elements(in_tensor) == Tensor::num_elements(out_tensor));
  // Ranks must match as well.
  LUCI_INTERPRETER_CHECK(Tensor::num_dims(in_tensor) == Tensor::num_dims(out_tensor));
}
42
43
/**
 * Runs a Cast operator by converting the input buffer into the output buffer.
 *
 * Supported combination: FLOAT32 input to S8, S16 or S32 output. That path is
 * compiled out when DIS_FLOAT is defined. Any other input type, or any other
 * output type for a FLOAT32 input, triggers an assert.
 * The number of converted elements is the flat size of the input's runtime shape.
 *
 * @param cur_op         the circle Cast operator being executed
 * @param runtime_graph  graph that provides tensor data and runtime shapes
 */
void execute_kernel_CircleCast(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
  kernels::SISOKernel siso(cur_op, runtime_graph);

  const auto *src_bytes = runtime_graph->getDataByTensor(siso.input());
  assert(src_bytes);

  auto *dst_bytes = runtime_graph->getDataByTensor(siso.output());
  assert(dst_bytes);

  // Both buffers are treated as flat arrays of this many elements.
  const int element_count = kernels::getTensorRuntimeShape(siso.input(), runtime_graph).flatSize();

  switch (Tensor::element_type(siso.input()))
  {
#ifndef DIS_FLOAT
    case DataType::FLOAT32:
    {
      const float *src = kernels::getTensorData<float>(src_bytes);
      const auto dst_type = Tensor::element_type(siso.output());

      if (dst_type == DataType::S8)
      {
        copyCast(src, kernels::getTensorData<int8_t>(dst_bytes), element_count);
      }
      else if (dst_type == DataType::S16)
      {
        copyCast(src, kernels::getTensorData<int16_t>(dst_bytes), element_count);
      }
      else if (dst_type == DataType::S32)
      {
        copyCast(src, kernels::getTensorData<int32_t>(dst_bytes), element_count);
      }
      else
      {
        assert(false && "Not supported type");
      }
      break;
    }
#endif // DIS_FLOAT
    default:
      assert(false && "Unsupported type");
  }
}
83
}
// namespace luci_interpreter
SISOKernel.h
luci_interpreter::RuntimeGraph
Definition
RuntimeGraph.h:33
luci_interpreter::RuntimeGraph::getDataByTensor
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
Definition
RuntimeGraph.cpp:355
luci_interpreter::RuntimeShape::flatSize
int flatSize() const
Definition
Tensor.h:45
luci_interpreter::kernels::SISOKernel
Definition
SISOKernel.h:29
luci_interpreter::kernels::SISOKernel::output
const circle::Tensor * output() const
Definition
SISOKernel.h:47
luci_interpreter::kernels::SISOKernel::input
const circle::Tensor * input() const
Definition
SISOKernel.h:46
LUCI_INTERPRETER_CHECK
#define LUCI_INTERPRETER_CHECK(cond)
Definition
Utils.h:36
luci_interpreter::kernels::getTensorRuntimeShape
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
Definition
Utils.cpp:29
luci_interpreter
Definition
BuddyMemoryManager.h:22
luci_interpreter::configure_kernel_CircleCast
void configure_kernel_CircleCast(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition
Cast.cpp:34
luci_interpreter::execute_kernel_CircleCast
void execute_kernel_CircleCast(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition
Cast.cpp:43
luci::must_cast
T must_cast(loco::Node *node)
Definition
CircleNodeDecl.h:95
Builders.h
onert-micro
luci-interpreter
src
kernels
Cast.cpp
Generated by
1.9.8