ONE - On-device Neural Engine
|
Enumerations | |
enum class | DataLayout { NHWC , NCHW } |
Enum class representing the TensorFlow "data_format" attribute. More... | |
Functions | |
bool | has_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
bool | has_attrs (const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names) |
tensorflow::DataType | get_datatype_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
const tensorflow::TensorShapeProto & | get_shape_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
const tensorflow::TensorProto & | get_tensor_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
const tensorflow::AttrValue_ListValue & | get_list_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
const std::string & | get_string_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
int64_t | get_int_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
float | get_float_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
bool | get_bool_attr (const tensorflow::NodeDef &node, const std::string &attr_name) |
std::vector< int64_t > | as_int64_list (const tensorflow::AttrValue_ListValue &lv) |
loco::DataType | as_loco_datatype (const tensorflow::DataType dtype) |
DataLayout | as_data_layout (const std::string &tf_layout_str) |
Convert a TF data layout string (e.g., "NHWC") to the DataLayout enum class for programming convenience. | |
DataLayout | get_data_layout (const tensorflow::NodeDef &node, const std::string &attr_name) |
void | copy_shape (const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape) |
Copy shape defined in TensorShapeProto to angkor shape. | |
bool | parse_graphdef (char const *pbtxt, tensorflow::GraphDef &graphdef) |
bool | parse_nodedef (char const *pbtxt, tensorflow::NodeDef &nodedef) |
|
strong |
DataLayout plier::tf::as_data_layout | ( | const std::string & | tf_layout_str | ) |
Convert a TF data layout string (e.g., "NHWC") to the DataLayout enum class for programming convenience.
Definition at line 146 of file Convert.cpp.
std::vector< int64_t > plier::tf::as_int64_list | ( | const tensorflow::AttrValue_ListValue & | lv | ) |
Definition at line 111 of file Convert.cpp.
Referenced by moco::AvgPoolGraphBuilder::build(), moco::Conv2DGraphBuilder::build(), moco::MaxPoolGraphBuilder::build(), moco::Conv2DBackpropInputGraphBuilder::build(), moco::DepthwiseConv2dNativeGraphBuilder::build(), moco::SqueezeGraphBuilder::build(), moco::AvgPoolGraphBuilder::validate(), moco::MaxPoolGraphBuilder::validate(), and moco::DepthwiseConv2dNativeGraphBuilder::validate().
loco::DataType plier::tf::as_loco_datatype | ( | const tensorflow::DataType | dtype | ) |
Definition at line 123 of file Convert.cpp.
Referenced by moco::ConstGraphBuilder::build(), moco::PlaceholderGraphBuilder::build(), moco::ShapeGraphBuilder::build(), moco::ConstGraphBuilder::validate(), and moco::PlaceholderGraphBuilder::validate().
void plier::tf::copy_shape | ( | const tensorflow::TensorShapeProto & | tf_shape, |
nncc::core::ADT::tensor::Shape & | to_shape | ||
) |
Copy shape defined in TensorShapeProto to angkor shape.
Definition at line 168 of file Convert.cpp.
References nncc::core::ADT::tensor::Shape::dim(), and nncc::core::ADT::tensor::Shape::resize().
bool plier::tf::get_bool_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 103 of file Convert.cpp.
References has_attr().
Referenced by moco::FakeQuantWithMinMaxVarsGraphBuilder::build(), and moco::MeanGraphBuilder::build().
DataLayout plier::tf::get_data_layout | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 156 of file Convert.cpp.
References get_string_attr(), NCHW, and NHWC.
tensorflow::DataType plier::tf::get_datatype_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 43 of file Convert.cpp.
References has_attr().
Referenced by moco::ConstGraphBuilder::build(), moco::PlaceholderGraphBuilder::build(), moco::ShapeGraphBuilder::build(), moco::ConstGraphBuilder::validate(), moco::MeanGraphBuilder::validate(), and moco::PlaceholderGraphBuilder::validate().
float plier::tf::get_float_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 95 of file Convert.cpp.
References has_attr().
Referenced by moco::FusedBatchNormGraphBuilder::build().
int64_t plier::tf::get_int_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 87 of file Convert.cpp.
References has_attr().
Referenced by moco::PackGraphBuilder::build(), moco::StridedSliceGraphBuilder::build(), moco::FakeQuantWithMinMaxVarsGraphBuilder::build(), moco::ConcatV2GraphBuilder::validate(), and moco::PackGraphBuilder::validate().
const ::tensorflow::AttrValue_ListValue & plier::tf::get_list_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 70 of file Convert.cpp.
References has_attr().
Referenced by moco::AvgPoolGraphBuilder::build(), moco::Conv2DGraphBuilder::build(), moco::MaxPoolGraphBuilder::build(), moco::Conv2DBackpropInputGraphBuilder::build(), moco::DepthwiseConv2dNativeGraphBuilder::build(), moco::SqueezeGraphBuilder::build(), moco::AvgPoolGraphBuilder::validate(), moco::Conv2DGraphBuilder::validate(), moco::MaxPoolGraphBuilder::validate(), moco::Conv2DBackpropInputGraphBuilder::validate(), and moco::DepthwiseConv2dNativeGraphBuilder::validate().
const tensorflow::TensorShapeProto & plier::tf::get_shape_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 52 of file Convert.cpp.
References has_attr().
Referenced by moco::PlaceholderGraphBuilder::build().
const std::string & plier::tf::get_string_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 79 of file Convert.cpp.
References has_attr().
Referenced by moco::AvgPoolGraphBuilder::build(), moco::BiasAddGraphBuilder::build(), moco::Conv2DGraphBuilder::build(), moco::MaxPoolGraphBuilder::build(), moco::Conv2DBackpropInputGraphBuilder::build(), moco::DepthwiseConv2dNativeGraphBuilder::build(), get_data_layout(), moco::BiasAddGraphBuilder::validate(), moco::Conv2DGraphBuilder::validate(), moco::MaxPoolGraphBuilder::validate(), moco::Conv2DBackpropInputGraphBuilder::validate(), and moco::DepthwiseConv2dNativeGraphBuilder::validate().
const tensorflow::TensorProto & plier::tf::get_tensor_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 61 of file Convert.cpp.
References has_attr().
Referenced by moco::ConstGraphBuilder::build(), and moco::ConstGraphBuilder::validate().
bool plier::tf::has_attr | ( | const tensorflow::NodeDef & | node, |
const std::string & | attr_name | ||
) |
Definition at line 30 of file Convert.cpp.
Referenced by moco::FakeQuantWithMinMaxVarsGraphBuilder::build(), get_bool_attr(), get_datatype_attr(), get_float_attr(), get_int_attr(), get_list_attr(), get_shape_attr(), get_string_attr(), get_tensor_attr(), has_attrs(), moco::Conv2DGraphBuilder::validate(), and moco::Conv2DBackpropInputGraphBuilder::validate().
bool plier::tf::has_attrs | ( | const tensorflow::NodeDef & | node, |
const std::vector< std::string > & | attr_names | ||
) |
Definition at line 35 of file Convert.cpp.
References has_attr().
Referenced by moco::ShapeGraphBuilder::build(), moco::SqueezeGraphBuilder::build(), moco::AvgPoolGraphBuilder::validate(), moco::BiasAddGraphBuilder::validate(), moco::ConcatV2GraphBuilder::validate(), moco::ConstGraphBuilder::validate(), moco::Conv2DGraphBuilder::validate(), moco::MaxPoolGraphBuilder::validate(), moco::PackGraphBuilder::validate(), moco::StridedSliceGraphBuilder::validate(), moco::Conv2DBackpropInputGraphBuilder::validate(), moco::DepthwiseConv2dNativeGraphBuilder::validate(), moco::FusedBatchNormGraphBuilder::validate(), moco::MeanGraphBuilder::validate(), moco::PadGraphBuilder::validate(), moco::PlaceholderGraphBuilder::validate(), moco::ReshapeGraphBuilder::validate(), moco::ShapeGraphBuilder::validate(), moco::SoftmaxGraphBuilder::validate(), moco::SqueezeGraphBuilder::validate(), and moco::StopGradientGraphBuilder::validate().
bool plier::tf::parse_graphdef | ( | char const * | pbtxt, |
tensorflow::GraphDef & | graphdef | ||
) |
Definition at line 55 of file TestHelper.cpp.
bool plier::tf::parse_nodedef | ( | char const * | pbtxt, |
tensorflow::NodeDef & | nodedef | ||
) |
Definition at line 62 of file TestHelper.cpp.