20#include <tensorflow/core/framework/graph.pb.h>
31std::unique_ptr<T> open_fstream(
const std::string &path, std::ios_base::openmode mode)
38 auto stream = std::make_unique<T>(path.c_str(), mode);
39 if (!stream->is_open())
41 throw std::runtime_error{
"Failed to open " + path};
53bool HasAttr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
55 return node.attr().count(attr_name) > 0;
58tensorflow::DataType
GetDataTypeAttr(
const tensorflow::NodeDef &node,
const std::string &attr_name)
60 assert(
HasAttr(node, attr_name));
61 const auto &attr = node.attr().at(attr_name);
62 assert(attr.value_case() == tensorflow::AttrValue::kType);
66tensorflow::TensorProto *
GetTensorAttr(tensorflow::NodeDef &node,
const std::string &attr_name)
68 assert(
HasAttr(node, attr_name));
69 tensorflow::AttrValue &attr = node.mutable_attr()->at(attr_name);
70 assert(attr.value_case() == tensorflow::AttrValue::kTensor);
71 return attr.mutable_tensor();
78 for (
auto &d : shape.dim())
98 throw std::runtime_error(
"Argument index out of bound");
100 return std::string(_argv[index]);
108 return std::string(_argv[index]);
113 auto iocfg = std::make_unique<IOConfiguration>();
115 auto in = open_fstream<std::ifstream>(cmdargs.
get_or(0,
"-"), std::ios::in | std::ios::binary);
116 iocfg->in(std::move(in));
118 auto out = open_fstream<std::ofstream>(cmdargs.
get_or(1,
"-"), std::ios::out | std::ios::binary);
119 iocfg->out(std::move(out));
std::string get_or(unsigned int index, const std::string &) const
std::string get(unsigned int index) const
bool HasAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
int GetElementCount(const tensorflow::TensorShapeProto &shape)
GetElementCount returns -1 for rank-0 tensor shape.
tensorflow::TensorProto * GetTensorAttr(tensorflow::NodeDef &node, const std::string &attr_name)
tensorflow::DataType GetDataTypeAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
std::unique_ptr< IOConfiguration > make_ioconfig(const CmdArguments &cmdargs)