30bool has_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
32 return node.attr().count(attr_name) > 0;
35bool has_attrs(
const tensorflow::NodeDef &node,
const std::vector<std::string> &attr_names)
37 for (
auto &attr : attr_names)
44 const std::string &attr_name)
47 const auto &attr = node.attr().at(attr_name);
48 assert(attr.value_case() == tensorflow::AttrValue::kType);
52const tensorflow::TensorShapeProto &
get_shape_attr(
const tensorflow::NodeDef &node,
53 const std::string &attr_name)
56 const auto &attr = node.attr().at(attr_name);
57 assert(attr.value_case() == tensorflow::AttrValue::kShape);
62 const std::string &attr_name)
65 const auto &attr = node.attr().at(attr_name);
66 assert(attr.value_case() == tensorflow::AttrValue::kTensor);
70const ::tensorflow::AttrValue_ListValue &
get_list_attr(
const tensorflow::NodeDef &node,
71 const std::string &attr_name)
74 const auto &attr = node.attr().at(attr_name);
75 assert(attr.value_case() == tensorflow::AttrValue::kList);
79const std::string &
get_string_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
82 const auto &attr = node.attr().at(attr_name);
83 assert(attr.value_case() == tensorflow::AttrValue::kS);
87int64_t
get_int_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
90 const auto &attr = node.attr().at(attr_name);
91 assert(attr.value_case() == tensorflow::AttrValue::kI);
95float get_float_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
98 const auto &attr = node.attr().at(attr_name);
99 assert(attr.value_case() == tensorflow::AttrValue::kF);
103bool get_bool_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
106 const auto &attr = node.attr().at(attr_name);
107 assert(attr.value_case() == tensorflow::AttrValue::kB);
111std::vector<int64_t>
as_int64_list(
const tensorflow::AttrValue_ListValue &lv)
113 std::vector<int64_t> vi;
114 int isize = lv.i_size();
117 for (
int i = 0; i < isize; ++i)
127 case tensorflow::DT_INT8:
128 return loco::DataType::S8;
129 case tensorflow::DT_UINT8:
130 return loco::DataType::U8;
131 case tensorflow::DT_FLOAT:
132 return loco::DataType::FLOAT32;
133 case tensorflow::DT_INT32:
134 return loco::DataType::S32;
135 case tensorflow::DT_INT64:
136 return loco::DataType::S64;
137 case tensorflow::DT_BOOL:
138 case tensorflow::DT_STRING:
139 case tensorflow::DT_COMPLEX64:
143 throw std::runtime_error{
"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(tf_dtype)};
148 if (tf_layout_str ==
"NHWC")
150 else if (tf_layout_str ==
"NCHW")
153 throw std::runtime_error(
"unknown data layout");
160 if (layout ==
"NHWC")
162 else if (layout ==
"NCHW")
165 throw std::runtime_error(
"unknown data layout");
171 assert(!tf_shape.unknown_rank());
173 int64_t tf_rank = tf_shape.dim_size();
174 assert(tf_rank < std::numeric_limits<uint32_t>::max());
176 int32_t rank =
static_cast<int32_t
>(tf_rank);
179 for (int32_t d = 0; d < rank; d++)
181 int64_t dim_value = tf_shape.
dim(d).size();
182 assert(dim_value < std::numeric_limits<uint32_t>::max());
184 if (dim_value >= 0LL)
186 uint32_t dim_value32 =
static_cast<uint32_t
>(dim_value);
187 to_shape.
dim(d) = dim_value32;
191 throw std::runtime_error(
"Cannot handle unknown dimension");
uint32_t & dim(uint32_t axis)
Shape & resize(uint32_t size)
DataType
"scalar" value type
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const tensorflow::TensorProto & get_tensor_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const tensorflow::TensorShapeProto & get_shape_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
DataLayout
Class to represent TensorFlow "data_format" attr.
DataLayout as_data_layout(const std::string &tf_layout_str)
@ brief Convert TF Data Layout string (e.g., "NHWC") to enum class for programming convenience
void copy_shape(const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape)
Copy shape defined in TensorShapeProto to angkor shape.
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)