23#include <CustomOpInfo.pb.h>
25#include <google/protobuf/io/zero_copy_stream_impl.h>
26#include <google/protobuf/text_format.h>
34bool load_text(
const cwrap::Fildes &fildes, tf2tflite::CustomOpInfoDef &def)
36 google::protobuf::io::FileInputStream fis(fildes.
get());
38 return google::protobuf::TextFormat::Parse(&fis, &def);
45 int64_t rank64 = shape.dim_size();
46 assert(rank64 < std::numeric_limits<uint32_t>::max());
48 int32_t rank =
static_cast<int32_t
>(rank64);
51 for (int32_t d = 0; d < rank; d++)
53 int64_t dim_value = shape.
dim(d).size();
54 assert(dim_value >= 0ULL);
55 assert(dim_value < std::numeric_limits<uint32_t>::max());
57 uint32_t dim_value32 =
static_cast<uint32_t
>(dim_value);
58 to_shape.
dim(d) = dim_value32;
// Translate the custom-op dtype into the corresponding loco::DataType.
// Only DT_FLOAT (-> FLOAT32) and DT_INT32 (-> S32) are handled.
66 if (dtype == tf2tflite::DT_FLOAT)
67 return loco::DataType::FLOAT32;
68 else if (dtype == tf2tflite::DT_INT32)
69 return loco::DataType::S32;
// Every other dtype value is reported to the caller as std::runtime_error.
71 throw std::runtime_error(
"Not yet supported datatype. Cannot convert.");
77loco::DataType get_dtype_attr(
const tf2tflite::CustomOpDef &custom_op)
79 std::string type_attr_name(
"dtype");
81 assert(custom_op.attr().count(type_attr_name) > 0);
82 const auto &attr = custom_op.attr().at(type_attr_name);
83 assert(attr.value_case() == tf2tflite::AttrValue::kType);
84 auto dtype_def = attr.type();
86 return convert_dtype(dtype_def);
// Look up the "output_shape" attribute of custom_op and convert it with
// convert_shape().
91 std::string shape_attr_name(
"output_shape");
// Presence of the attribute and its value kind (kShape) are checked with
// assert only, so neither is enforced when assertions are compiled out.
93 assert(custom_op.attr().count(shape_attr_name) > 0);
94 const auto &attr = custom_op.attr().at(shape_attr_name);
95 assert(attr.value_case() == tf2tflite::AttrValue::kShape);
// NOTE(review): plain `auto` stores the shape by value; if attr.shape()
// returns a reference, `const auto &` would avoid a copy - confirm.
96 auto shape_def = attr.shape();
98 return convert_shape(shape_def);
// Walk over every custom op entry stored in def.
// NOTE(review): the embedded line numbers skip values here, so this loop may
// have further statements that are not visible in this span - confirm.
103 for (
const auto &custom_op : def.custom_op())
// The name read from the entry is the key used for registration in sig.
107 auto name = custom_op.name();
// Register the dtype taken from the entry's attributes under that name.
110 sig.
dtype(name, get_dtype_attr(custom_op));
// A negative descriptor value is reported as a "not found" error that names path.
126 if (fildes.get() < 0)
128 throw std::runtime_error{
"Error: " + path +
" not found"};
// load_text() returning false is reported as a prototxt parse failure.
131 if (!load_text(fildes, def))
133 throw std::runtime_error{
"Error: Failed to parse prototxt " + path};
// Hand the parsed definition over to add_customop() to fill sig.
136 add_customop(def, sig);
uint32_t & dim(uint32_t axis)
Shape & resize(uint32_t size)
DataType
"scalar" value type
const tensorflow::TensorShapeProto & get_shape_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
void load_customop_conf(const std::string &path, moco::ModelSignature &sig)
Loads customop.conf into ModelSignature.
Class to store information to run a model. Normally this info comes from users via CLI params or conf...
void dtype(const std::string &node_name, loco::DataType dtype)
Adds node name and its dtype provided from user.
void add_customop(const std::string &op)
Adds customop op type (not name of node) provided from user.
void shape(const std::string &node_name, const angkor::TensorShape &shape)
Adds node name and its shape provided from user.