26#include <onnx/onnx.pb.h>
28#include <google/protobuf/io/coded_stream.h>
29#include <google/protobuf/io/zero_copy_stream_impl.h>
30#include <google/protobuf/text_format.h>
42bool load_text(
const cwrap::Fildes &fildes, onnx::ModelProto &model_proto)
44 google::protobuf::io::FileInputStream fis(fildes.
get());
46 return google::protobuf::TextFormat::Parse(&fis, &model_proto);
/**
 * @brief Read an ONNX ModelProto stored in protobuf binary (wire) format
 *        from an already-open file descriptor.
 *
 * The descriptor is wrapped in a FileInputStream, and that stream is wrapped
 * in a CodedInputStream.
 * NOTE(review): the rest of this body, including the parse call and the
 * returned value, is not visible in this listing. It presumably returns
 * whether parsing succeeded, as load_text does. Confirm.
 * NOTE(review): older protobuf releases cap CodedInputStream at 64 MB by
 * default. Large models need SetTotalBytesLimit. Check whether the hidden
 * lines set it.
 */
49bool load_binary(
const cwrap::Fildes &fildes, onnx::ModelProto &model_proto)
51 google::protobuf::io::FileInputStream fis(fildes.
get());
// CodedInputStream decodes the binary wire format on top of the raw stream.
52 google::protobuf::io::CodedInputStream cis(&fis);
// Tail of the model-file loader. Callers invoke it as
// load_onnx(path, type, model_proto). Its name and leading parameters are not
// visible in this listing.
// It fills `model_proto` from the file at `path`. load_binary is the visible
// loader; the text case presumably uses load_text.
// @throws std::runtime_error "Error: <path> not found" or
//         "Error: Failed to parse <path>".
58 onnx::ModelProto &model_proto)
// Presumably reached when the file cannot be opened. The check itself is not shown.
64 throw std::runtime_error{
"Error: " + path +
" not found"};
// Second arm of the loader selection: load_binary handles the non-text case.
// The condition and the first arm are not shown.
68 : load_binary(fildes, model_proto);
// Presumably reached when the selected loader reports failure.
72 throw std::runtime_error{
"Error: Failed to parse " + path};
/**
 * @brief Convert an ONNX ModelProto into nodes, inputs and outputs of a loco::Graph.
 *
 * Visible phases (several lines of this function are not shown in this listing):
 *  1. choose the opset version from the model's opset_import entries;
 *  2. build IR for every ONNX node through its graph builder;
 *  3. turn each initializer tensor into a FLOAT32 constant node;
 *  4. handle graph inputs that are not initializers;
 *  5. wire node inputs by name using the recorded symbol tables;
 *  6. create graph inputs/outputs named after the ONNX inputs/outputs.
 *
 * @param onnx_model_proto model to convert. The visible lines only read it.
 *        NOTE(review): it could likely be a const reference. Confirm against the hidden lines.
 * @param graph destination graph, populated in place.
 * @throws std::runtime_error for an unsupported custom-operation opset,
 *         for an operator that fails validation ("Invalid operator: ..."), and
 *         for an operator that is not supported ("Not supported: ...").
 */
77void convert_graph(::onnx::ModelProto &onnx_model_proto,
loco::Graph *graph)
// `nodes`: ONNX value name -> loco node producing it (see enroll()/node() below).
// `input_names`: per loco node, the ONNX names of its inputs. These are
// resolved in the wiring phase, after all nodes exist.
79 auto nodes = std::make_unique<moco::onnx::SymbolTable>();
80 auto input_names = std::make_unique<moco::onnx::SymbolTable>();
// NOTE(review): assert() is compiled out under NDEBUG, so a model without a
// graph is not rejected in release builds.
92 assert(onnx_model_proto.has_graph());
// NOTE(review): this copies the whole GraphProto, including initializer data.
// `const auto &` would avoid the copy.
93 ::onnx::GraphProto onnx_graph_proto = onnx_model_proto.graph();
// Opset selection: start at 1 and keep the largest version among the entries
// that reach the comparison below. Any filtering lines are not shown.
98 assert(onnx_model_proto.opset_import_size() > 0);
99 int64_t opset_version = 1;
100 for (
int i = 0; i < onnx_model_proto.opset_import_size(); ++i)
102 auto opset = onnx_model_proto.opset_import(i);
106 if (opset.version() > opset_version)
108 opset_version = opset.version();
// The triggering condition is not visible. It is presumably an opset entry
// outside the default domain (see is_default_domain).
113 throw std::runtime_error(
"Not supported for custom operation");
// Build IR for each ONNX node. `graph_builder` and `gb_context` are set up in
// lines not shown; presumably this is a GraphBuilderRegistry lookup by op type.
118 for (
const auto &n : onnx_graph_proto.node())
122 if (!graph_builder->validate(opset_version, n))
124 throw std::runtime_error{
"Invalid operator: " + n.op_type()};
127 graph_builder->build(opset_version, n, &gb_context);
// Presumably reached when no builder exists for this op type.
131 throw std::runtime_error{
"Not supported: " + n.op_type()};
// Initializers become constant nodes. Their names are remembered so the input
// loops below can tell them apart from real graph inputs.
136 std::set<std::string> initializer_name_set;
137 for (
int i = 0; i < onnx_graph_proto.initializer_size(); ++i)
// NOTE(review): copies the TensorProto, which may be large. `const auto &` would avoid it.
139 auto initializer = onnx_graph_proto.initializer(i);
141 initializer_name_set.insert(initializer.name());
// `const_node` and `data` come from lines not shown (presumably
// get_float_data(initializer)). Only the FLOAT32 layout is visible here.
148 const_node->rank(initializer.dims_size());
150 const_node->
size<loco::DataType::FLOAT32>(
data.size());
// Copy the shape.
// NOTE(review): ONNX dims are int64. Confirm that narrowing into the loco
// dimension type is acceptable.
152 for (uint32_t i = 0; i < const_node->rank(); ++i)
154 const_node->dim(i) = initializer.dims(i);
// Copy the element values.
157 for (uint32_t i = 0; i <
data.size(); ++i)
160 const_node->at<loco::DataType::FLOAT32>(i) =
data.at(i);
// Make the constant resolvable by name when wiring node inputs.
163 nodes->enroll(initializer.name(), const_node);
// Graph inputs that are not initializers. Initializers were handled above.
167 for (
int i = 0; i < onnx_graph_proto.input_size(); i++)
169 auto input = onnx_graph_proto.input(i);
172 if (initializer_name_set.find(
input.name()) != initializer_name_set.end())
// Shape setup on the input's pull node. The loop body is not shown.
180 for (uint32_t i = 0; i <
pull_node->rank(); ++i)
// Wire node inputs. For every node in the graph, resolve each recorded input
// name to the node that produces it.
190 uint32_t nodes_count = graph_nodes->
size();
191 for (uint32_t n = 0; n < nodes_count; ++n)
// `node_to_set` is obtained from graph_nodes in a line not shown.
195 unsigned int names_size = input_names->size(node_to_set);
// One recorded name is expected per operand of the node.
196 assert(names_size == node_to_set->
arity());
197 for (
unsigned int i = 0; i < names_size; ++i)
199 auto input_name = input_names->name(node_to_set, i);
200 auto node = nodes->node(input_name);
// `forward_node` is presumably node_to_set cast to loco::Forward; the cast is
// not shown. Other node kinds may be handled in hidden lines.
204 if (forward_node !=
nullptr)
205 forward_node->
input(node);
// Create a named graph input for each non-initializer ONNX input. It is
// associated with its pull node; the linking call itself is not shown.
210 for (
int i = 0; i < onnx_graph_proto.input_size(); i++)
212 auto input = onnx_graph_proto.input(i).name();
215 if (initializer_name_set.find(input) != initializer_name_set.end())
218 auto node = nodes->node(input);
219 assert(node !=
nullptr);
221 auto graph_input =
graph->inputs()->create();
224 assert(pull_node !=
nullptr);
226 graph_input->name(input);
// Create a graph output named after each ONNX output. The producer/push wiring
// is in lines not shown.
231 for (
int i = 0; i < onnx_graph_proto.output_size(); i++)
233 auto output = onnx_graph_proto.output(i).name();
243 auto graph_output =
graph->outputs()->create();
244 graph_output->name(output);
// Frontend load path. The enclosing signature is not visible here.
// It parses the model file into a ModelProto, then converts it into a loco graph.
263 ::onnx::ModelProto onnx_model_proto;
// The loader throws std::runtime_error when the file is missing or cannot be parsed.
265 load_onnx(modelfile, type, onnx_model_proto);
// `graph` is created in a line not shown, presumably via loco::make_graph().
269 convert_graph(onnx_model_proto, graph.get());
// NOTE(review): if the return type is exactly decltype(graph), this std::move
// is redundant and disables copy elision. It only matters for a converting
// return (e.g. unique_ptr<Derived> -> unique_ptr<Base>) under pre-CWG1579
// compilers. The return type is not visible; confirm before removing it.
271 return std::move(graph);
enco::Bundle load(void) const override
Create a value from constant byte array.
uint32_t size(void) const
Return the number of reserved elements.
Create a new value identical to its input.
Logical unit of computation.
virtual uint32_t arity(void) const =0
Return the number of arguments.
T * at(uint32_t n) const
Access N-th object.
uint32_t size(void) const
Return the number of objects.
Create a value from user data.
void dtype(const DataType &d)
Make a value visible to user.
Class that stores the context used to build the IR from an ONNX model.
static GraphBuilderRegistry & get()
void link(GraphOutput *, Push *push)
Pull * pull_node(Graph *g, const GraphInputIndex &index)
Find a Pull node with a given input index.
std::unique_ptr< Graph > make_graph(void)
Push * push_node(Graph *g, const GraphOutputIndex &index)
Find a Push node with a given output index.
CircleOutput * output_node(loco::Graph *g, const loco::GraphOutputIndex &index)
Find a CircleOutput node with a given output index.
std::vector< float > get_float_data(const ::onnx::TensorProto &tensor)
Get float tensor data.
loco::DataType as_loco_datatype(const int32_t tensor_dtype)
bool is_default_domain(const std::string domain)
Returns true if the domain is the empty string or "onnx.ai", i.e. the default ONNX domain.