ONE - On-device Neural Engine
Loading...
Searching...
No Matches
plier::tf Namespace Reference

Enumerations

enum class  DataLayout { NHWC , NCHW }
 Class to represent TensorFlow "data_format" attr. More...
 

Functions

bool has_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
bool has_attrs (const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
 
tensorflow::DataType get_datatype_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const tensorflow::TensorShapeProto & get_shape_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const tensorflow::TensorProto & get_tensor_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const tensorflow::AttrValue_ListValue & get_list_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
const std::string & get_string_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
int64_t get_int_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
float get_float_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
bool get_bool_attr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
std::vector< int64_t > as_int64_list (const tensorflow::AttrValue_ListValue &lv)
 
loco::DataType as_loco_datatype (const tensorflow::DataType dtype)
 
DataLayout as_data_layout (const std::string &tf_layout_str)
 Convert a TF data layout string (e.g., "NHWC") to the DataLayout enum class for programming convenience.
 
DataLayout get_data_layout (const tensorflow::NodeDef &node, const std::string &attr_name)
 
void copy_shape (const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape)
 Copy shape defined in TensorShapeProto to angkor shape.
 
bool parse_graphdef (char const *pbtxt, tensorflow::GraphDef &graphdef)
 
bool parse_nodedef (char const *pbtxt, tensorflow::NodeDef &nodedef)
 

Enumeration Type Documentation

◆ DataLayout

enum class plier::tf::DataLayout
strong

Class to represent TensorFlow "data_format" attr.

Enumerator
NHWC 
NCHW 

Definition at line 56 of file Convert.h.

57{
58 NHWC,
59 NCHW,
60};

Function Documentation

◆ as_data_layout()

DataLayout plier::tf::as_data_layout ( const std::string &  tf_layout_str)

Convert a TF data layout string ("NHWC" or "NCHW") to the DataLayout enum class for programming convenience. Any other string throws std::runtime_error("unknown data layout").

Definition at line 146 of file Convert.cpp.

147{
148 if (tf_layout_str == "NHWC")
149 return DataLayout::NHWC;
150 else if (tf_layout_str == "NCHW")
151 return DataLayout::NCHW;
152 else
153 throw std::runtime_error("unknown data layout");
154}

References NCHW, and NHWC.

◆ as_int64_list()

std::vector< int64_t > plier::tf::as_int64_list ( const tensorflow::AttrValue_ListValue &  lv)

◆ as_loco_datatype()

loco::DataType plier::tf::as_loco_datatype ( const tensorflow::DataType  dtype)

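Definition at line 123 of file Convert.cpp. In the definition the parameter is named tf_dtype. Only DT_INT8, DT_UINT8, DT_FLOAT, DT_INT32 and DT_INT64 are mapped (to S8, U8, FLOAT32, S32 and S64). Every other dtype, including DT_BOOL, DT_STRING and DT_COMPLEX64, throws std::runtime_error.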

124{
125 switch (tf_dtype)
126 {
127 case tensorflow::DT_INT8:
128 return loco::DataType::S8;
129 case tensorflow::DT_UINT8:
130 return loco::DataType::U8;
131 case tensorflow::DT_FLOAT:
132 return loco::DataType::FLOAT32;
133 case tensorflow::DT_INT32:
134 return loco::DataType::S32;
135 case tensorflow::DT_INT64:
136 return loco::DataType::S64;
137 case tensorflow::DT_BOOL:
138 case tensorflow::DT_STRING:
139 case tensorflow::DT_COMPLEX64:
140 default:
141 break;
142 }
143 throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(tf_dtype)};
144}

Referenced by moco::ConstGraphBuilder::build(), moco::PlaceholderGraphBuilder::build(), moco::ShapeGraphBuilder::build(), moco::ConstGraphBuilder::validate(), and moco::PlaceholderGraphBuilder::validate().

◆ copy_shape()

void plier::tf::copy_shape ( const tensorflow::TensorShapeProto &  tf_shape,
nncc::core::ADT::tensor::Shape &  to_shape 
)

Copy shape defined in TensorShapeProto to angkor shape.

Note
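Unknown rank is not supported (checked only by assert). An unknown (negative) dimension is not supported and throws std::runtime_error("Cannot handle unknown dimension").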

Definition at line 168 of file Convert.cpp.

170{
171 assert(!tf_shape.unknown_rank());
172
173 int64_t tf_rank = tf_shape.dim_size();
174 assert(tf_rank < std::numeric_limits<uint32_t>::max());
175
176 int32_t rank = static_cast<int32_t>(tf_rank);
177 to_shape.resize(rank);
178
179 for (int32_t d = 0; d < rank; d++)
180 {
181 int64_t dim_value = tf_shape.dim(d).size();
182 assert(dim_value < std::numeric_limits<uint32_t>::max());
183
184 if (dim_value >= 0LL)
185 {
186 uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
187 to_shape.dim(d) = dim_value32;
188 }
189 else
190 {
191 throw std::runtime_error("Cannot handle unknown dimension");
192 // TODO support unknown dimension
193 }
194 }
195}
uint32_t & dim(uint32_t axis)
Definition Shape.cpp:42
Shape & resize(uint32_t size)
Definition Shape.cpp:36

References nncc::core::ADT::tensor::Shape::dim(), and nncc::core::ADT::tensor::Shape::resize().

◆ get_bool_attr()

bool plier::tf::get_bool_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 103 of file Convert.cpp.

104{
105 assert(has_attr(node, attr_name));
106 const auto &attr = node.attr().at(attr_name);
107 assert(attr.value_case() == tensorflow::AttrValue::kB);
108 return attr.b();
109}
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:30

References has_attr().

Referenced by moco::FakeQuantWithMinMaxVarsGraphBuilder::build(), and moco::MeanGraphBuilder::build().

◆ get_data_layout()

DataLayout plier::tf::get_data_layout ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 156 of file Convert.cpp.

157{
158 auto layout = get_string_attr(node, attr_name);
159
160 if (layout == "NHWC")
161 return DataLayout::NHWC;
162 else if (layout == "NCHW")
163 return DataLayout::NCHW;
164 else
165 throw std::runtime_error("unknown data layout");
166}
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79

References get_string_attr(), NCHW, and NHWC.

◆ get_datatype_attr()

tensorflow::DataType plier::tf::get_datatype_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 43 of file Convert.cpp.

45{
46 assert(has_attr(node, attr_name));
47 const auto &attr = node.attr().at(attr_name);
48 assert(attr.value_case() == tensorflow::AttrValue::kType);
49 return attr.type();
50}

References has_attr().

Referenced by moco::ConstGraphBuilder::build(), moco::PlaceholderGraphBuilder::build(), moco::ShapeGraphBuilder::build(), moco::ConstGraphBuilder::validate(), moco::MeanGraphBuilder::validate(), and moco::PlaceholderGraphBuilder::validate().

◆ get_float_attr()

float plier::tf::get_float_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 95 of file Convert.cpp.

96{
97 assert(has_attr(node, attr_name));
98 const auto &attr = node.attr().at(attr_name);
99 assert(attr.value_case() == tensorflow::AttrValue::kF);
100 return attr.f();
101}

References has_attr().

Referenced by moco::FusedBatchNormGraphBuilder::build().

◆ get_int_attr()

int64_t plier::tf::get_int_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 87 of file Convert.cpp.

88{
89 assert(has_attr(node, attr_name));
90 const auto &attr = node.attr().at(attr_name);
91 assert(attr.value_case() == tensorflow::AttrValue::kI);
92 return attr.i();
93}

References has_attr().

Referenced by moco::PackGraphBuilder::build(), moco::StridedSliceGraphBuilder::build(), moco::FakeQuantWithMinMaxVarsGraphBuilder::build(), moco::ConcatV2GraphBuilder::validate(), and moco::PackGraphBuilder::validate().

◆ get_list_attr()

const ::tensorflow::AttrValue_ListValue & plier::tf::get_list_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

◆ get_shape_attr()

const tensorflow::TensorShapeProto & plier::tf::get_shape_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 52 of file Convert.cpp.

54{
55 assert(has_attr(node, attr_name));
56 const auto &attr = node.attr().at(attr_name);
57 assert(attr.value_case() == tensorflow::AttrValue::kShape);
58 return attr.shape();
59}

References has_attr().

Referenced by moco::PlaceholderGraphBuilder::build().

◆ get_string_attr()

const std::string & plier::tf::get_string_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

◆ get_tensor_attr()

const tensorflow::TensorProto & plier::tf::get_tensor_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 61 of file Convert.cpp.

63{
64 assert(has_attr(node, attr_name));
65 const auto &attr = node.attr().at(attr_name);
66 assert(attr.value_case() == tensorflow::AttrValue::kTensor);
67 return attr.tensor();
68}

References has_attr().

Referenced by moco::ConstGraphBuilder::build(), and moco::ConstGraphBuilder::validate().

◆ has_attr()

bool plier::tf::has_attr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

◆ has_attrs()

◆ parse_graphdef()

bool plier::tf::parse_graphdef ( char const *  pbtxt,
tensorflow::GraphDef &  graphdef 
)

Definition at line 55 of file TestHelper.cpp.

56{
57 imemstream mempb(pbtxt, std::strlen(pbtxt));
58 google::protobuf::io::IstreamInputStream iis(&mempb);
59 return google::protobuf::TextFormat::Parse(&iis, &graphdef);
60}

◆ parse_nodedef()

bool plier::tf::parse_nodedef ( char const *  pbtxt,
tensorflow::NodeDef &  nodedef 
)

Definition at line 62 of file TestHelper.cpp.

63{
64 imemstream mempb(pbtxt, std::strlen(pbtxt));
65 google::protobuf::io::IstreamInputStream iis(&mempb);
66 return google::protobuf::TextFormat::Parse(&iis, &nodedef);
67}