ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Cast.cpp
Go to the documentation of this file.
/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/Cast.h"
#include "kernels/Utils.h"

#include <algorithm>
#include <cassert>
#include <stdexcept>

21namespace
22{
23
24using namespace luci_interpreter;
25using namespace luci_interpreter::kernels;
26
27template <typename InT, typename OutT>
28void cast_data(const InT *in_data, OutT *out_data, uint32_t elements_count)
29{
30 std::transform(in_data, in_data + elements_count, out_data,
31 [](InT a) { return static_cast<OutT>(a); });
32}
33
34template <typename InT> void cast_from_pointer_to_tensor(const InT *in_data, Tensor *out_tensor)
35{
36 auto const out_type = out_tensor->element_type();
37 auto const elements_count = out_tensor->shape().num_elements();
38
39 switch (out_type)
40 {
41 case loco::DataType::U8:
42 cast_data(in_data, getTensorData<uint8_t>(out_tensor), elements_count);
43 break;
44 case loco::DataType::U16:
45 cast_data(in_data, getTensorData<uint16_t>(out_tensor), elements_count);
46 break;
47 case loco::DataType::U32:
48 cast_data(in_data, getTensorData<uint32_t>(out_tensor), elements_count);
49 break;
50 case loco::DataType::U64:
51 cast_data(in_data, getTensorData<uint64_t>(out_tensor), elements_count);
52 break;
53 case loco::DataType::S8:
54 cast_data(in_data, getTensorData<int8_t>(out_tensor), elements_count);
55 break;
56 case loco::DataType::S16:
57 cast_data(in_data, getTensorData<int16_t>(out_tensor), elements_count);
58 break;
59 case loco::DataType::S32:
60 cast_data(in_data, getTensorData<int32_t>(out_tensor), elements_count);
61 break;
62 case loco::DataType::S64:
63 cast_data(in_data, getTensorData<int64_t>(out_tensor), elements_count);
64 break;
65 case loco::DataType::FLOAT32:
66 cast_data(in_data, getTensorData<float>(out_tensor), elements_count);
67 break;
68 case loco::DataType::BOOL:
69 cast_data(in_data, getTensorData<bool>(out_tensor), elements_count);
70 break;
71 default:
72 throw std::runtime_error("Unsupported output type.");
73 }
74}
75
76void cast_from_tensor_to_tensor(const Tensor *in_tensor, Tensor *out_tensor)
77{
78 auto in_type = in_tensor->element_type();
79
80 switch (in_type)
81 {
82 case loco::DataType::U8:
83 cast_from_pointer_to_tensor(getTensorData<uint8_t>(in_tensor), out_tensor);
84 break;
85 case loco::DataType::U16:
86 cast_from_pointer_to_tensor(getTensorData<uint16_t>(in_tensor), out_tensor);
87 break;
88 case loco::DataType::U32:
89 cast_from_pointer_to_tensor(getTensorData<uint32_t>(in_tensor), out_tensor);
90 break;
91 case loco::DataType::U64:
92 cast_from_pointer_to_tensor(getTensorData<uint64_t>(in_tensor), out_tensor);
93 break;
94 case loco::DataType::S8:
95 cast_from_pointer_to_tensor(getTensorData<int8_t>(in_tensor), out_tensor);
96 break;
97 case loco::DataType::S16:
98 cast_from_pointer_to_tensor(getTensorData<int16_t>(in_tensor), out_tensor);
99 break;
100 case loco::DataType::S32:
101 cast_from_pointer_to_tensor(getTensorData<int32_t>(in_tensor), out_tensor);
102 break;
103 case loco::DataType::S64:
104 cast_from_pointer_to_tensor(getTensorData<int64_t>(in_tensor), out_tensor);
105 break;
106 case loco::DataType::FLOAT32:
107 cast_from_pointer_to_tensor(getTensorData<float>(in_tensor), out_tensor);
108 break;
109 case loco::DataType::BOOL:
110 cast_from_pointer_to_tensor(getTensorData<bool>(in_tensor), out_tensor);
111 break;
112 default:
113 throw std::runtime_error("Unsupported input type.");
114 }
115}
116
117} // namespace
118
119namespace luci_interpreter
120{
121namespace kernels
122{
123
124Cast::Cast(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
125
127{
128 LUCI_INTERPRETER_CHECK(input()->element_type() != loco::DataType::Unknown);
129 LUCI_INTERPRETER_CHECK(output()->element_type() != loco::DataType::Unknown);
130
131 const Shape &shape = input()->shape();
132 output()->resize(shape);
133}
134
135void Cast::execute() const
136{
137 assert(input()->shape().num_elements() == output()->shape().num_elements());
138
139 cast_from_tensor_to_tensor(input(), output());
140}
141
142} // namespace kernels
143} // namespace luci_interpreter
int32_t num_elements() const
Definition Tensor.h:53
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
DataType element_type() const
Definition Tensor.h:105
Cast(const Tensor *input, Tensor *output)
Definition Cast.cpp:124
void configure() override
Definition Cast.cpp:126
Tensor * output() const
Definition Cast.h:34
const Tensor * input() const
Definition Cast.h:33
void execute() const override
Definition Cast.cpp:135
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36