30bool has_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
32 return node.attr().count(attr_name) > 0;
35bool has_attrs(
const tensorflow::NodeDef &node,
const std::vector<std::string> &attr_names)
37 for (
auto &attr : attr_names)
44 const std::string &attr_name)
47 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kType);
48 return node.attr().at(attr_name).type();
51const tensorflow::TensorShapeProto &
get_shape_attr(
const tensorflow::NodeDef &node,
52 const std::string &attr_name)
55 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kShape);
56 return node.attr().at(attr_name).shape();
60 const std::string &attr_name)
63 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kTensor);
64 return node.attr().at(attr_name).tensor();
67const ::tensorflow::AttrValue_ListValue &
get_list_attr(
const tensorflow::NodeDef &node,
68 const std::string &attr_name)
71 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kList);
72 return node.attr().at(attr_name).list();
75const std::string &
get_string_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
78 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kS);
79 return node.attr().at(attr_name).s();
82int64_t
get_int_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
85 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kI);
86 return node.attr().at(attr_name).i();
89float get_float_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
92 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kF);
93 return node.attr().at(attr_name).f();
96bool get_bool_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
99 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kB);
100 return node.attr().at(attr_name).b();
103std::vector<int64_t>
as_int64_list(
const tensorflow::AttrValue_ListValue &lv)
105 std::vector<int64_t> vi;
106 int isize = lv.i_size();
109 for (
int i = 0; i < isize; ++i)
119 case tensorflow::DT_INT8:
120 return loco::DataType::S8;
121 case tensorflow::DT_UINT8:
122 return loco::DataType::U8;
123 case tensorflow::DT_FLOAT:
124 return loco::DataType::FLOAT32;
125 case tensorflow::DT_INT32:
126 return loco::DataType::S32;
127 case tensorflow::DT_INT64:
128 return loco::DataType::S64;
129 case tensorflow::DT_BOOL:
130 case tensorflow::DT_STRING:
131 case tensorflow::DT_COMPLEX64:
135 throw std::runtime_error{
"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(tf_dtype)};
140 if (tf_layout_str ==
"NHWC")
142 else if (tf_layout_str ==
"NCHW")
145 throw std::runtime_error(
"unknown data layout");
152 if (layout ==
"NHWC")
154 else if (layout ==
"NCHW")
157 throw std::runtime_error(
"unknown data layout");
163 assert(!tf_shape.unknown_rank());
165 int64_t tf_rank = tf_shape.dim_size();
166 assert(tf_rank < std::numeric_limits<uint32_t>::max());
168 int32_t rank =
static_cast<int32_t
>(tf_rank);
171 for (int32_t d = 0; d < rank; d++)
173 int64_t dim_value = tf_shape.
dim(d).size();
174 assert(dim_value < std::numeric_limits<uint32_t>::max());
176 if (dim_value >= 0LL)
178 uint32_t dim_value32 =
static_cast<uint32_t
>(dim_value);
179 to_shape.
dim(d) = dim_value32;
183 throw std::runtime_error(
"Cannot handle unknown dimension");
uint32_t & dim(uint32_t axis)
Shape & resize(uint32_t size)
DataType
"scalar" value type
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const tensorflow::TensorProto & get_tensor_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const tensorflow::TensorShapeProto & get_shape_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
DataLayout
Class to represent TensorFlow "data_format" attr.
DataLayout as_data_layout(const std::string &tf_layout_str)
@brief Convert TF Data Layout string (e.g., "NHWC") to enum class for programming convenience
void copy_shape(const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape)
Copy shape defined in TensorShapeProto to angkor shape.
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)