18#ifndef __PLIER_TF_CONVERT_H__
19#define __PLIER_TF_CONVERT_H__
24#include <tensorflow/core/framework/graph.pb.h>
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a TensorFlow NodeDef plus one attribute name and returns bool.
 *
 * NOTE(review): presumably reports whether `node` carries an attribute named
 * `attr_name` -- confirm against the definition.
 */
33bool has_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a TensorFlow NodeDef plus a list of attribute names and
 *        returns a single bool.
 *
 * NOTE(review): presumably true only when every name in `attr_names` is
 * present on `node` (the multi-name counterpart of has_attr) -- confirm
 * against the definition, including the result for an empty list.
 */
34bool has_attrs(
const tensorflow::NodeDef &node,
const std::vector<std::string> &attr_names);
37 const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a NodeDef and an attribute name and returns a const reference
 *        to a tensorflow::TensorShapeProto.
 *
 * NOTE(review): presumably the shape stored under `attr_name`; the referent's
 * owner (likely `node`) and the behavior for a missing or differently typed
 * attribute are not visible here -- confirm before holding the reference.
 */
38const tensorflow::TensorShapeProto &
get_shape_attr(
const tensorflow::NodeDef &node,
39 const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a NodeDef and an attribute name and returns a const reference
 *        to a tensorflow::TensorProto.
 *
 * NOTE(review): presumably the tensor stored under `attr_name`; the referent's
 * owner (likely `node`) and the behavior for a missing or differently typed
 * attribute are not visible here -- confirm before holding the reference.
 */
40const tensorflow::TensorProto &
get_tensor_attr(
const tensorflow::NodeDef &node,
41 const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a NodeDef and an attribute name and returns a const reference
 *        to a tensorflow::AttrValue_ListValue.
 *
 * NOTE(review): presumably the list value stored under `attr_name`; the
 * referent's owner (likely `node`) and the behavior for a missing or
 * differently typed attribute are not visible here -- confirm.
 */
42const tensorflow::AttrValue_ListValue &
get_list_attr(
const tensorflow::NodeDef &node,
43 const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a NodeDef and an attribute name and returns a const reference
 *        to a std::string.
 *
 * NOTE(review): presumably the string stored under `attr_name`; because a
 * reference is returned, do not assume it outlives `node` -- confirm the
 * referent's owner and the missing-attribute behavior in the definition.
 */
44const std::string &
get_string_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a NodeDef and an attribute name and returns an int64_t by value.
 *
 * NOTE(review): presumably the integer stored under `attr_name`; the behavior
 * for a missing or differently typed attribute is not visible here -- confirm.
 */
45int64_t
get_int_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a NodeDef and an attribute name and returns a float by value.
 *
 * NOTE(review): presumably the float stored under `attr_name`; the behavior
 * for a missing or differently typed attribute is not visible here -- confirm.
 */
46float get_float_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a NodeDef and an attribute name and returns a bool by value.
 *
 * NOTE(review): presumably the boolean stored under `attr_name` (as opposed to
 * has_attr, whose bool would report presence) -- confirm, along with the
 * behavior for a missing or differently typed attribute.
 */
47bool get_bool_attr(
const tensorflow::NodeDef &node,
const std::string &attr_name);
/**
 * @brief Declaration only; the definition is not part of this view.
 *        Takes a tensorflow::AttrValue_ListValue by const reference and
 *        returns a std::vector<int64_t> by value.
 *
 * NOTE(review): presumably copies the integer entries of `lv` in order; how
 * non-integer entries are treated is not visible here -- confirm.
 */
49std::vector<int64_t>
as_int64_list(
const tensorflow::AttrValue_ListValue &lv);
72void copy_shape(
const tensorflow::TensorShapeProto &tf_shape,
DataType
"scalar" value type
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const tensorflow::TensorProto & get_tensor_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
const tensorflow::TensorShapeProto & get_shape_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
DataLayout
Class to represent TensorFlow "data_format" attr.
DataLayout as_data_layout(const std::string &tf_layout_str)
@brief Convert a TF data layout string (e.g., "NHWC") to an enum class for programming convenience
void copy_shape(const tensorflow::TensorShapeProto &tf_shape, nncc::core::ADT::tensor::Shape &to_shape)
Copy shape defined in TensorShapeProto to angkor shape.
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)