ONE - On-device Neural Engine
Loading...
Searching...
No Matches
plier::tf Namespace Reference

Enumerations

enum class  DataLayout { NHWC , NCHW }
 Enum class representing the TensorFlow "data_format" attribute. More...
 

Functions

bool has_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
bool has_attrs (const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
 
tensorflow::DataType get_datatype_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const tensorflow::TensorShapeProto & get_shape_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const tensorflow::TensorProto & get_tensor_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const tensorflow::AttrValue_ListValue & get_list_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const std::string & get_string_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
int64_t get_int_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
float get_float_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
bool get_bool_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
std::vector< int64_t > as_int64_list (const tensorflow::AttrValue_ListValue &lv)
 
loco::DataType as_loco_datatype (const tensorflow::DataType dtype)
 
DataLayout as_data_layout (const std::string &tf_layout_str)
 Convert a TF data layout string (e.g., "NHWC") to the DataLayout enum class for programming convenience.
 
DataLayout get_data_layout (const tensorflow::NodeDef &node, const std::string &attr_name)
 
void copy_shape (const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape)
 Copy the shape defined in a TensorShapeProto into an angkor (nncc::core::ADT::tensor) shape.
 
bool parse_graphdef (char const *pbtxt, tensorflow::GraphDef &graphdef)
 
bool parse_nodedef (char const *pbtxt, tensorflow::NodeDef &nodedef)
 

Enumeration Type Documentation

◆ DataLayout

enum class plier::tf::DataLayout
strong

Enum class representing the TensorFlow "data_format" attribute.

Enumerator
NHWC 
NCHW 

Definition at line 56 of file Convert.h.

57{
58 NHWC,
59 NCHW,
60};

Function Documentation

◆ as_data_layout()

DataLayout plier::tf::as_data_layout ( const std::string &  tf_layout_str)

Convert a TF data layout string ("NHWC" or "NCHW") to the DataLayout enum class for programming convenience. Any other string throws std::runtime_error("unknown data layout").

Definition at line 138 of file Convert.cpp.

139{
140 if (tf_layout_str == "NHWC")
141 return DataLayout::NHWC;
142 else if (tf_layout_str == "NCHW")
143 return DataLayout::NCHW;
144 else
145 throw std::runtime_error("unknown data layout");
146}

References NCHW, and NHWC.

◆ as_int64_list()

std::vector< int64_t > plier::tf::as_int64_list ( const tensorflow::AttrValue_ListValue &  lv)

◆ as_loco_datatype()

loco::DataType plier::tf::as_loco_datatype ( const tensorflow::DataType  dtype)

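Definition at line 115 of file Convert.cpp. (The definition names the parameter tf_dtype. Only DT_INT8, DT_UINT8, DT_FLOAT, DT_INT32 and DT_INT64 are converted; any other type, including DT_BOOL, DT_STRING and DT_COMPLEX64, throws std::runtime_error.)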

116{
117 switch (tf_dtype)
118 {
119 case tensorflow::DT_INT8:
120 return loco::DataType::S8;
121 case tensorflow::DT_UINT8:
122 return loco::DataType::U8;
123 case tensorflow::DT_FLOAT:
124 return loco::DataType::FLOAT32;
125 case tensorflow::DT_INT32:
126 return loco::DataType::S32;
127 case tensorflow::DT_INT64:
128 return loco::DataType::S64;
129 case tensorflow::DT_BOOL:
130 case tensorflow::DT_STRING:
131 case tensorflow::DT_COMPLEX64:
132 default:
133 break;
134 }
135 throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(tf_dtype)};
136}

Referenced by moco::ConstGraphBuilder::build(), moco::PlaceholderGraphBuilder::build(), moco::ShapeGraphBuilder::build(), moco::ConstGraphBuilder::validate(), and moco::PlaceholderGraphBuilder::validate().

◆ copy_shape()

void plier::tf::copy_shape ( const tensorflow::TensorShapeProto &  tf_shape,
nncc::core::ADT::tensor::Shape &  to_shape 
)

Copy the shape defined in a TensorShapeProto into an angkor (nncc::core::ADT::tensor) shape.

Note
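Unknown dimensions are not supported: a negative dimension size throws std::runtime_error("Cannot handle unknown dimension"). An unknown rank is rejected only by an assert, so it is not checked when NDEBUG is defined.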

Definition at line 160 of file Convert.cpp.

162{
163 assert(!tf_shape.unknown_rank());
164
165 int64_t tf_rank = tf_shape.dim_size();
166 assert(tf_rank < std::numeric_limits<uint32_t>::max());
167
168 int32_t rank = static_cast<int32_t>(tf_rank);
169 to_shape.resize(rank);
170
171 for (int32_t d = 0; d < rank; d++)
172 {
173 int64_t dim_value = tf_shape.dim(d).size();
174 assert(dim_value < std::numeric_limits<uint32_t>::max());
175
176 if (dim_value >= 0LL)
177 {
178 uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
179 to_shape.dim(d) = dim_value32;
180 }
181 else
182 {
183 throw std::runtime_error("Cannot handle unknown dimension");
184 // TODO support unknown dimension
185 }
186 }
187}
uint32_t & dim(uint32_t axis)
Definition Shape.cpp:42
Shape & resize(uint32_t size)
Definition Shape.cpp:36

References nncc::core::ADT::tensor::Shape::dim(), and nncc::core::ADT::tensor::Shape::resize().

◆ get_bool_attr()

bool plier::tf::get_bool_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 96 of file Convert.cpp.

97{
98 assert(has_attr(node, attr_name));
99 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kB);
100 return node.attr().at(attr_name).b();
101}
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:30

References has_attr().

Referenced by moco::FakeQuantWithMinMaxVarsGraphBuilder::build(), and moco::MeanGraphBuilder::build().

◆ get_data_layout()

DataLayout plier::tf::get_data_layout ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 148 of file Convert.cpp.

149{
150 auto layout = get_string_attr(node, attr_name);
151
152 if (layout == "NHWC")
153 return DataLayout::NHWC;
154 else if (layout == "NCHW")
155 return DataLayout::NCHW;
156 else
157 throw std::runtime_error("unknown data layout");
158}
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:75

References get_string_attr(), NCHW, and NHWC.

◆ get_datatype_attr()

tensorflow::DataType plier::tf::get_datatype_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 43 of file Convert.cpp.

45{
46 assert(has_attr(node, attr_name));
47 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kType);
48 return node.attr().at(attr_name).type();
49}

References has_attr().

Referenced by moco::ConstGraphBuilder::build(), moco::PlaceholderGraphBuilder::build(), moco::ShapeGraphBuilder::build(), moco::ConstGraphBuilder::validate(), moco::MeanGraphBuilder::validate(), and moco::PlaceholderGraphBuilder::validate().

◆ get_float_attr()

float plier::tf::get_float_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 89 of file Convert.cpp.

90{
91 assert(has_attr(node, attr_name));
92 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kF);
93 return node.attr().at(attr_name).f();
94}

References has_attr().

Referenced by moco::FusedBatchNormGraphBuilder::build().

◆ get_int_attr()

int64_t plier::tf::get_int_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 82 of file Convert.cpp.

83{
84 assert(has_attr(node, attr_name));
85 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kI);
86 return node.attr().at(attr_name).i();
87}

References has_attr().

Referenced by moco::PackGraphBuilder::build(), moco::StridedSliceGraphBuilder::build(), moco::FakeQuantWithMinMaxVarsGraphBuilder::build(), moco::ConcatV2GraphBuilder::validate(), and moco::PackGraphBuilder::validate().

◆ get_list_attr()

const ::tensorflow::AttrValue_ListValue & plier::tf::get_list_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

◆ get_shape_attr()

const tensorflow::TensorShapeProto & plier::tf::get_shape_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 51 of file Convert.cpp.

53{
54 assert(has_attr(node, attr_name));
55 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kShape);
56 return node.attr().at(attr_name).shape();
57}

References has_attr().

Referenced by moco::PlaceholderGraphBuilder::build().

◆ get_string_attr()

const std::string & plier::tf::get_string_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

◆ get_tensor_attr()

const tensorflow::TensorProto & plier::tf::get_tensor_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 59 of file Convert.cpp.

61{
62 assert(has_attr(node, attr_name));
63 assert(node.attr().at(attr_name).value_case() == tensorflow::AttrValue::kTensor);
64 return node.attr().at(attr_name).tensor();
65}

References has_attr().

Referenced by moco::ConstGraphBuilder::build(), and moco::ConstGraphBuilder::validate().

◆ has_attr()

bool plier::tf::has_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

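◆ has_attrs()

bool plier::tf::has_attrs ( const tensorflow::NodeDef &  node,
const std::vector< std::string > &  attr_names 
)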

◆ parse_graphdef()

bool plier::tf::parse_graphdef ( char const *  pbtxt,
tensorflow::GraphDef &  graphdef 
)

Definition at line 55 of file TestHelper.cpp.

56{
57 imemstream mempb(pbtxt, std::strlen(pbtxt));
58 google::protobuf::io::IstreamInputStream iis(&mempb);
59 return google::protobuf::TextFormat::Parse(&iis, &graphdef);
60}

◆ parse_nodedef()

bool plier::tf::parse_nodedef ( char const *  pbtxt,
tensorflow::NodeDef &  nodedef 
)

Definition at line 62 of file TestHelper.cpp.

63{
64 imemstream mempb(pbtxt, std::strlen(pbtxt));
65 google::protobuf::io::IstreamInputStream iis(&mempb);
66 return google::protobuf::TextFormat::Parse(&iis, &nodedef);
67}