ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tfkit::tf Namespace Reference

Functions

bool HasAttr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
tensorflow::DataType GetDataTypeAttr (const tensorflow::NodeDef &node, const std::string &attr_name)
 
tensorflow::TensorProto * GetTensorAttr (tensorflow::NodeDef &node, const std::string &attr_name)
 
int GetElementCount (const tensorflow::TensorShapeProto &)
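 Returns the number of elements described by the shape (the product of its dimension sizes), 0 if any dimension has size 0, or -1 for a rank-0 (scalar) shape.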
 

Function Documentation

◆ GetDataTypeAttr()

tensorflow::DataType tfkit::tf::GetDataTypeAttr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 58 of file Support.cpp.

59{
60 assert(HasAttr(node, attr_name));
61 const auto &attr = node.attr().at(attr_name);
62 assert(attr.value_case() == tensorflow::AttrValue::kType);
63 return attr.type();
64}
bool HasAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Support.cpp:53

References HasAttr().

◆ GetElementCount()

int tfkit::tf::GetElementCount ( const tensorflow::TensorShapeProto &  shape)

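GetElementCount returns the product of all dimension sizes in the shape. It returns 0 as soon as any dimension has size 0. It returns -1 for a rank-0 (scalar) shape, even though a scalar holds one value, so callers must handle -1 explicitly. Dimension sizes are multiplied as-is, so a negative size (TensorFlow uses -1 for an unknown dimension) produces a negative or otherwise meaningless count.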

Definition at line 74 of file Support.cpp.

75{
76 int count = -1;
77
78 for (auto &d : shape.dim())
79 {
80 if (d.size() == 0)
81 {
82 count = 0;
83 break;
84 }
85 if (count == -1)
86 count = 1;
87
88 count *= d.size();
89 }
90 return count;
91}

◆ GetTensorAttr()

tensorflow::TensorProto * tfkit::tf::GetTensorAttr ( tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 66 of file Support.cpp.

67{
68 assert(HasAttr(node, attr_name));
69 tensorflow::AttrValue &attr = node.mutable_attr()->at(attr_name);
70 assert(attr.value_case() == tensorflow::AttrValue::kTensor);
71 return attr.mutable_tensor();
72}

References HasAttr().

◆ HasAttr()

bool tfkit::tf::HasAttr ( const tensorflow::NodeDef &  node,
const std::string &  attr_name 
)

Definition at line 53 of file Support.cpp.

54{
55 return node.attr().count(attr_name) > 0;
56}

Referenced by GetDataTypeAttr(), and GetTensorAttr().