ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Convert.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include <plier/tf/Convert.h>
19
21
22#include <cassert>
23#include <stdexcept>
24
25namespace plier
26{
27namespace tf
28{
29
30bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
31{
32 return node.attr().count(attr_name) > 0;
33}
34
35bool has_attrs(const tensorflow::NodeDef &node, const std::vector<std::string> &attr_names)
36{
37 for (auto &attr : attr_names)
38 if (!has_attr(node, attr))
39 return false;
40 return true;
41}
42
43tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node,
44 const std::string &attr_name)
45{
46 assert(has_attr(node, attr_name));
47 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kType);
48 return node.attr().at(attr_name).type();
49}
50
51const tensorflow::TensorShapeProto &get_shape_attr(const tensorflow::NodeDef &node,
52 const std::string &attr_name)
53{
54 assert(has_attr(node, attr_name));
55 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kShape);
56 return node.attr().at(attr_name).shape();
57}
58
59const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node,
60 const std::string &attr_name)
61{
62 assert(has_attr(node, attr_name));
63 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kTensor);
64 return node.attr().at(attr_name).tensor();
65}
66
67const ::tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node,
68 const std::string &attr_name)
69{
70 assert(has_attr(node, attr_name));
71 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kList);
72 return node.attr().at(attr_name).list();
73}
74
75const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
76{
77 assert(has_attr(node, attr_name));
78 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kS);
79 return node.attr().at(attr_name).s();
80}
81
82int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
83{
84 assert(has_attr(node, attr_name));
85 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kI);
86 return node.attr().at(attr_name).i();
87}
88
89float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
90{
91 assert(has_attr(node, attr_name));
92 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kF);
93 return node.attr().at(attr_name).f();
94}
95
96bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
97{
98 assert(has_attr(node, attr_name));
99 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kB);
100 return node.attr().at(attr_name).b();
101}
102
103std::vector<int64_t> as_int64_list(const tensorflow::AttrValue_ListValue &lv)
104{
105 std::vector<int64_t> vi;
106 int isize = lv.i_size();
107
108 vi.resize(isize);
109 for (int i = 0; i < isize; ++i)
110 vi[i] = lv.i(i);
111
112 return vi;
113}
114
115loco::DataType as_loco_datatype(const tensorflow::DataType tf_dtype)
116{
117 switch (tf_dtype)
118 {
119 case tensorflow::DT_INT8:
120 return loco::DataType::S8;
121 case tensorflow::DT_UINT8:
122 return loco::DataType::U8;
123 case tensorflow::DT_FLOAT:
124 return loco::DataType::FLOAT32;
125 case tensorflow::DT_INT32:
126 return loco::DataType::S32;
127 case tensorflow::DT_INT64:
128 return loco::DataType::S64;
129 case tensorflow::DT_BOOL:
130 case tensorflow::DT_STRING:
131 case tensorflow::DT_COMPLEX64:
132 default:
133 break;
134 }
135 throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(tf_dtype)};
136}
137
138DataLayout as_data_layout(const std::string &tf_layout_str)
139{
140 if (tf_layout_str == "NHWC")
141 return DataLayout::NHWC;
142 else if (tf_layout_str == "NCHW")
143 return DataLayout::NCHW;
144 else
145 throw std::runtime_error("unknown data layout");
146}
147
148DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
149{
150 auto layout = get_string_attr(node, attr_name);
151
152 if (layout == "NHWC")
153 return DataLayout::NHWC;
154 else if (layout == "NCHW")
155 return DataLayout::NCHW;
156 else
157 throw std::runtime_error("unknown data layout");
158}
159
160void copy_shape(const tensorflow::TensorShapeProto &tf_shape,
162{
163 assert(!tf_shape.unknown_rank());
164
165 int64_t tf_rank = tf_shape.dim_size();
166 assert(tf_rank < std::numeric_limits<uint32_t>::max());
167
168 int32_t rank = static_cast<int32_t>(tf_rank);
169 to_shape.resize(rank);
170
171 for (int32_t d = 0; d < rank; d++)
172 {
173 int64_t dim_value = tf_shape.dim(d).size();
174 assert(dim_value < std::numeric_limits<uint32_t>::max());
175
176 if (dim_value >= 0LL)
177 {
178 uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
179 to_shape.dim(d) = dim_value32;
180 }
181 else
182 {
183 throw std::runtime_error("Cannot handle unknown dimension");
184 // TODO support unknown dimension
185 }
186 }
187}
188
189} // namespace tf
190} // namespace plier
uint32_t & dim(uint32_t axis)
Definition Shape.cpp:42
Shape & resize(uint32_t size)
Definition Shape.cpp:36
DataType
"scalar" value type
Definition DataType.h:27
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:30
bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:96
const tensorflow::TensorProto & get_tensor_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:59
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:75
tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:43
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
Definition Convert.cpp:115
DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:148
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:103
float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:89
int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:82
const tensorflow::TensorShapeProto & get_shape_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:51
DataLayout
Class to represent TensorFlow "data_format" attr.
Definition Convert.h:57
DataLayout as_data_layout(const std::string &tf_layout_str)
Convert TF Data Layout string (e.g., "NHWC") to enum class for programming convenience
Definition Convert.cpp:138
void copy_shape(const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape)
Copy shape defined in TensorShapeProto to angkor shape.
Definition Convert.cpp:160
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:67