ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Cast.cpp
Go to the documentation of this file.
1
/*
2
* Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4
*
5
* Licensed under the Apache License, Version 2.0 (the "License");
6
* you may not use this file except in compliance with the License.
7
* You may obtain a copy of the License at
8
*
9
* http://www.apache.org/licenses/LICENSE-2.0
10
*
11
* Unless required by applicable law or agreed to in writing, software
12
* distributed under the License is distributed on an "AS IS" BASIS,
13
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
* See the License for the specific language governing permissions and
15
* limitations under the License.
16
*/
17
18
#include "
kernels/Cast.h
"
19
#include "kernels/Utils.h"
20
21
namespace
22
{
23
24
using namespace
luci_interpreter
;
25
using namespace
luci_interpreter::kernels
;
26
27
template
<
typename
InT,
typename
OutT>
28
void
cast_data
(
const
InT
*
in_data
,
OutT
*
out_data
,
uint32_t
elements_count
)
29
{
30
std::transform(
in_data
,
in_data
+
elements_count
,
out_data
,
31
[](
InT
a) {
return
static_cast<
OutT
>
(a); });
32
}
33
34
template
<
typename
InT>
void
cast_from_pointer_to_tensor
(
const
InT
*
in_data
,
Tensor
*
out_tensor
)
35
{
36
auto
const
out_type =
out_tensor
->element_type();
37
auto
const
elements_count
=
out_tensor
->shape().num_elements();
38
39
switch
(out_type)
40
{
41
case
loco::DataType::U8:
42
cast_data
(
in_data
,
getTensorData<uint8_t>
(
out_tensor
),
elements_count
);
43
break
;
44
case
loco::DataType::U16:
45
cast_data
(
in_data
,
getTensorData<uint16_t>
(
out_tensor
),
elements_count
);
46
break
;
47
case
loco::DataType::U32:
48
cast_data
(
in_data
,
getTensorData<uint32_t>
(
out_tensor
),
elements_count
);
49
break
;
50
case
loco::DataType::U64:
51
cast_data
(
in_data
,
getTensorData<uint64_t>
(
out_tensor
),
elements_count
);
52
break
;
53
case
loco::DataType::S8:
54
cast_data
(
in_data
,
getTensorData<int8_t>
(
out_tensor
),
elements_count
);
55
break
;
56
case
loco::DataType::S16:
57
cast_data
(
in_data
,
getTensorData<int16_t>
(
out_tensor
),
elements_count
);
58
break
;
59
case
loco::DataType::S32:
60
cast_data
(
in_data
,
getTensorData<int32_t>
(
out_tensor
),
elements_count
);
61
break
;
62
case
loco::DataType::S64:
63
cast_data
(
in_data
,
getTensorData<int64_t>
(
out_tensor
),
elements_count
);
64
break
;
65
case
loco::DataType::FLOAT32:
66
cast_data
(
in_data
,
getTensorData<float>
(
out_tensor
),
elements_count
);
67
break
;
68
case
loco::DataType::BOOL:
69
cast_data
(
in_data
,
getTensorData<bool>
(
out_tensor
),
elements_count
);
70
break
;
71
default
:
72
throw
std::runtime_error(
"Unsupported output type."
);
73
}
74
}
75
76
void
cast_from_tensor_to_tensor
(
const
Tensor
*
in_tensor
,
Tensor
*
out_tensor
)
77
{
78
auto
in_type
=
in_tensor
->element_type();
79
80
switch
(
in_type
)
81
{
82
case
loco::DataType::U8:
83
cast_from_pointer_to_tensor
(
getTensorData<uint8_t>
(
in_tensor
),
out_tensor
);
84
break
;
85
case
loco::DataType::U16:
86
cast_from_pointer_to_tensor
(
getTensorData<uint16_t>
(
in_tensor
),
out_tensor
);
87
break
;
88
case
loco::DataType::U32:
89
cast_from_pointer_to_tensor
(
getTensorData<uint32_t>
(
in_tensor
),
out_tensor
);
90
break
;
91
case
loco::DataType::U64:
92
cast_from_pointer_to_tensor
(
getTensorData<uint64_t>
(
in_tensor
),
out_tensor
);
93
break
;
94
case
loco::DataType::S8:
95
cast_from_pointer_to_tensor
(
getTensorData<int8_t>
(
in_tensor
),
out_tensor
);
96
break
;
97
case
loco::DataType::S16:
98
cast_from_pointer_to_tensor
(
getTensorData<int16_t>
(
in_tensor
),
out_tensor
);
99
break
;
100
case
loco::DataType::S32:
101
cast_from_pointer_to_tensor
(
getTensorData<int32_t>
(
in_tensor
),
out_tensor
);
102
break
;
103
case
loco::DataType::S64:
104
cast_from_pointer_to_tensor
(
getTensorData<int64_t>
(
in_tensor
),
out_tensor
);
105
break
;
106
case
loco::DataType::FLOAT32:
107
cast_from_pointer_to_tensor
(
getTensorData<float>
(
in_tensor
),
out_tensor
);
108
break
;
109
case
loco::DataType::BOOL:
110
cast_from_pointer_to_tensor
(
getTensorData<bool>
(
in_tensor
),
out_tensor
);
111
break
;
112
default
:
113
throw
std::runtime_error(
"Unsupported input type."
);
114
}
115
}
116
117
}
// namespace
118
119
namespace
luci_interpreter
120
{
121
namespace
kernels
122
{
123
124
Cast::Cast
(
const
Tensor
*input,
Tensor
*output) :
Kernel
({
input
}, {output}) {}
125
126
void
Cast::configure
()
127
{
128
LUCI_INTERPRETER_CHECK
(
input
()->element_type() != loco::DataType::Unknown);
129
LUCI_INTERPRETER_CHECK
(
output
()->element_type() != loco::DataType::Unknown);
130
131
const
Shape
&shape =
input
()->
shape
();
132
output
()->
resize
(shape);
133
}
134
135
void
Cast::execute
()
const
136
{
137
assert(
input
()->shape().num_elements() ==
output
()->shape().num_elements());
138
139
cast_from_tensor_to_tensor
(
input
(),
output
());
140
}
141
142
}
// namespace kernels
143
}
// namespace luci_interpreter
luci_interpreter::Kernel
Definition
Kernel.h:29
luci_interpreter::Shape
Definition
Tensor.h:33
luci_interpreter::Tensor
Definition
Tensor.h:101
luci_interpreter::Tensor::resize
void resize(const Shape &new_shape)
Definition
Tensor.cpp:56
luci_interpreter::Tensor::shape
const Shape & shape() const
Definition
Tensor.h:107
luci_interpreter::kernels::Cast::Cast
Cast(const Tensor *input, Tensor *output)
Definition
Cast.cpp:124
luci_interpreter::kernels::Cast::configure
void configure() override
Definition
Cast.cpp:126
luci_interpreter::kernels::Cast::output
Tensor * output() const
Definition
Cast.h:34
luci_interpreter::kernels::Cast::input
const Tensor * input() const
Definition
Cast.h:33
luci_interpreter::kernels::Cast::execute
void execute() const override
Definition
Cast.cpp:135
LUCI_INTERPRETER_CHECK
#define LUCI_INTERPRETER_CHECK(cond)
Definition
Utils.h:36
Cast.h
luci_interpreter::kernels
Definition
Abs.cpp:26
luci_interpreter
Definition
BuddyMemoryManager.h:22
luci::must_cast
T must_cast(loco::Node *node)
Definition
CircleNodeDecl.h:95
compiler
luci-interpreter
src
kernels
Cast.cpp
Generated by
1.9.8