ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Convert.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
#include <plier/tf/Convert.h>

#include <cassert>
#include <limits>
#include <stdexcept>

25namespace plier
26{
27namespace tf
28{
29
30bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
31{
32 return node.attr().count(attr_name) > 0;
33}
34
35bool has_attrs(const tensorflow::NodeDef &node, const std::vector<std::string> &attr_names)
36{
37 for (auto &attr : attr_names)
38 if (!has_attr(node, attr))
39 return false;
40 return true;
41}
42
43tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node,
44 const std::string &attr_name)
45{
46 assert(has_attr(node, attr_name));
47 const auto &attr = node.attr().at(attr_name);
48 assert(attr.value_case() == tensorflow::AttrValue::kType);
49 return attr.type();
50}
51
52const tensorflow::TensorShapeProto &get_shape_attr(const tensorflow::NodeDef &node,
53 const std::string &attr_name)
54{
55 assert(has_attr(node, attr_name));
56 const auto &attr = node.attr().at(attr_name);
57 assert(attr.value_case() == tensorflow::AttrValue::kShape);
58 return attr.shape();
59}
60
61const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node,
62 const std::string &attr_name)
63{
64 assert(has_attr(node, attr_name));
65 const auto &attr = node.attr().at(attr_name);
66 assert(attr.value_case() == tensorflow::AttrValue::kTensor);
67 return attr.tensor();
68}
69
70const ::tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node,
71 const std::string &attr_name)
72{
73 assert(has_attr(node, attr_name));
74 const auto &attr = node.attr().at(attr_name);
75 assert(attr.value_case() == tensorflow::AttrValue::kList);
76 return attr.list();
77}
78
79const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
80{
81 assert(has_attr(node, attr_name));
82 const auto &attr = node.attr().at(attr_name);
83 assert(attr.value_case() == tensorflow::AttrValue::kS);
84 return attr.s();
85}
86
87int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
88{
89 assert(has_attr(node, attr_name));
90 const auto &attr = node.attr().at(attr_name);
91 assert(attr.value_case() == tensorflow::AttrValue::kI);
92 return attr.i();
93}
94
95float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
96{
97 assert(has_attr(node, attr_name));
98 const auto &attr = node.attr().at(attr_name);
99 assert(attr.value_case() == tensorflow::AttrValue::kF);
100 return attr.f();
101}
102
103bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
104{
105 assert(has_attr(node, attr_name));
106 const auto &attr = node.attr().at(attr_name);
107 assert(attr.value_case() == tensorflow::AttrValue::kB);
108 return attr.b();
109}
110
111std::vector<int64_t> as_int64_list(const tensorflow::AttrValue_ListValue &lv)
112{
113 std::vector<int64_t> vi;
114 int isize = lv.i_size();
115
116 vi.resize(isize);
117 for (int i = 0; i < isize; ++i)
118 vi[i] = lv.i(i);
119
120 return vi;
121}
122
123loco::DataType as_loco_datatype(const tensorflow::DataType tf_dtype)
124{
125 switch (tf_dtype)
126 {
127 case tensorflow::DT_INT8:
128 return loco::DataType::S8;
129 case tensorflow::DT_UINT8:
130 return loco::DataType::U8;
131 case tensorflow::DT_FLOAT:
132 return loco::DataType::FLOAT32;
133 case tensorflow::DT_INT32:
134 return loco::DataType::S32;
135 case tensorflow::DT_INT64:
136 return loco::DataType::S64;
137 case tensorflow::DT_BOOL:
138 case tensorflow::DT_STRING:
139 case tensorflow::DT_COMPLEX64:
140 default:
141 break;
142 }
143 throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(tf_dtype)};
144}
145
146DataLayout as_data_layout(const std::string &tf_layout_str)
147{
148 if (tf_layout_str == "NHWC")
149 return DataLayout::NHWC;
150 else if (tf_layout_str == "NCHW")
151 return DataLayout::NCHW;
152 else
153 throw std::runtime_error("unknown data layout");
154}
155
156DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
157{
158 auto layout = get_string_attr(node, attr_name);
159
160 if (layout == "NHWC")
161 return DataLayout::NHWC;
162 else if (layout == "NCHW")
163 return DataLayout::NCHW;
164 else
165 throw std::runtime_error("unknown data layout");
166}
167
168void copy_shape(const tensorflow::TensorShapeProto &tf_shape,
170{
171 assert(!tf_shape.unknown_rank());
172
173 int64_t tf_rank = tf_shape.dim_size();
174 assert(tf_rank < std::numeric_limits<uint32_t>::max());
175
176 int32_t rank = static_cast<int32_t>(tf_rank);
177 to_shape.resize(rank);
178
179 for (int32_t d = 0; d < rank; d++)
180 {
181 int64_t dim_value = tf_shape.dim(d).size();
182 assert(dim_value < std::numeric_limits<uint32_t>::max());
183
184 if (dim_value >= 0LL)
185 {
186 uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
187 to_shape.dim(d) = dim_value32;
188 }
189 else
190 {
191 throw std::runtime_error("Cannot handle unknown dimension");
192 // TODO support unknown dimension
193 }
194 }
195}
196
197} // namespace tf
198} // namespace plier
uint32_t & dim(uint32_t axis)
Definition Shape.cpp:42
Shape & resize(uint32_t size)
Definition Shape.cpp:36
DataType
"scalar" value type
Definition DataType.h:27
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:30
bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:103
const tensorflow::TensorProto & get_tensor_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:61
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79
tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:43
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
Definition Convert.cpp:123
DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:156
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:111
float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:95
int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:87
const tensorflow::TensorShapeProto & get_shape_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:52
DataLayout
Class to represent TensorFlow "data_format" attr.
Definition Convert.h:57
DataLayout as_data_layout(const std::string &tf_layout_str)
Convert a TF data layout string (e.g., "NHWC") to the DataLayout enum class for programming convenience
Definition Convert.cpp:146
void copy_shape(const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape)
Copy shape defined in TensorShapeProto to angkor shape.
Definition Convert.cpp:168
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:70