ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Frontend.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <moco/onnx/Frontend.h>

#include "Convert.h"
#include "GraphBuilderContext.h"
#include "GraphBuilderRegistry.h"
#include "Onnxutil.h"

#include <cwrap/Fildes.h>

#include <onnx/onnx.pb.h>

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>

#include <cassert>
#include <cstdint>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <stdexcept>

#include <fcntl.h>
#include <unistd.h>
39namespace
40{
41
42bool load_text(const cwrap::Fildes &fildes, onnx::ModelProto &model_proto)
43{
44 google::protobuf::io::FileInputStream fis(fildes.get());
45
46 return google::protobuf::TextFormat::Parse(&fis, &model_proto);
47}
48
49bool load_binary(const cwrap::Fildes &fildes, onnx::ModelProto &model_proto)
50{
51 google::protobuf::io::FileInputStream fis(fildes.get());
52 google::protobuf::io::CodedInputStream cis(&fis);
53
54 return model_proto.ParseFromCodedStream(&cis);
55}
56
57void load_onnx(const std::string &path, moco::onnx::Frontend::FileType type,
58 onnx::ModelProto &model_proto)
59{
60 cwrap::Fildes fildes{open(path.c_str(), O_RDONLY)};
61
62 if (fildes.get() < 0)
63 {
64 throw std::runtime_error{"Error: " + path + " not found"};
65 }
66
67 bool result = (type == moco::onnx::Frontend::FileType::Text) ? load_text(fildes, model_proto)
68 : load_binary(fildes, model_proto);
69
70 if (!result)
71 {
72 throw std::runtime_error{"Error: Failed to parse " + path};
73 }
74}
75
76// TODO Make comments clear
77void convert_graph(::onnx::ModelProto &onnx_model_proto, loco::Graph *graph)
78{
79 auto nodes = std::make_unique<moco::onnx::SymbolTable>();
80 auto input_names = std::make_unique<moco::onnx::SymbolTable>();
81
82 moco::onnx::GraphBuilderContext gb_context(graph, nodes.get(), input_names.get());
83
84 // Building a loco graph
85 // 1. Convert onnx::node to loco::Node
86 // 2. Convert onnx::initializer to loco::ConstGen node
87 // 3. Convert onnx::input to loco::Pull node
88 // 4. Connect inputs: set all node input(from a string) to actual node object
89 // 5. Set graph input
90 // 6. Create loco::Push node (with a proper input), and mark it as a graph output
91
92 assert(onnx_model_proto.has_graph());
93 ::onnx::GraphProto onnx_graph_proto = onnx_model_proto.graph();
94
98 assert(onnx_model_proto.opset_import_size() > 0);
99 int64_t opset_version = 1;
100 for (int i = 0; i < onnx_model_proto.opset_import_size(); ++i)
101 {
102 auto opset = onnx_model_proto.opset_import(i);
103
104 if (!opset.has_domain() || moco::onnx::is_default_domain(opset.domain()))
105 {
106 if (opset.version() > opset_version)
107 {
108 opset_version = opset.version();
109 }
110 }
111 else
112 {
113 throw std::runtime_error("Not supported for custom operation");
114 }
115 }
116
117 // 1. Convert all the nodes to loco::Node
118 for (const auto &n : onnx_graph_proto.node())
119 {
120 if (const auto *graph_builder = moco::onnx::GraphBuilderRegistry::get().lookup(n.op_type()))
121 {
122 if (!graph_builder->validate(opset_version, n))
123 {
124 throw std::runtime_error{"Invalid operator: " + n.op_type()};
125 }
126
127 graph_builder->build(opset_version, n, &gb_context);
128 }
129 else
130 {
131 throw std::runtime_error{"Not supported: " + n.op_type()};
132 }
133 }
134
135 // 2. Convert onnx::initializer to loco::ConstGen node
136 std::set<std::string> initializer_name_set;
137 for (int i = 0; i < onnx_graph_proto.initializer_size(); ++i)
138 {
139 auto initializer = onnx_graph_proto.initializer(i);
140
141 initializer_name_set.insert(initializer.name());
142
143 // TODO Support other data types
144 auto data = moco::onnx::get_float_data(initializer);
145
146 auto const_node = graph->nodes()->create<loco::ConstGen>();
147 const_node->dtype(moco::onnx::as_loco_datatype(initializer.data_type()));
148 const_node->rank(initializer.dims_size());
149 // TODO Support other data types
150 const_node->size<loco::DataType::FLOAT32>(data.size());
151
152 for (uint32_t i = 0; i < const_node->rank(); ++i)
153 {
154 const_node->dim(i) = initializer.dims(i);
155 }
156
157 for (uint32_t i = 0; i < data.size(); ++i)
158 {
159 // TODO Support other data types
160 const_node->at<loco::DataType::FLOAT32>(i) = data.at(i);
161 }
162
163 nodes->enroll(initializer.name(), const_node);
164 }
165
166 // 3. Convert onnx::input to loco::Pull node
167 for (int i = 0; i < onnx_graph_proto.input_size(); i++)
168 {
169 auto input = onnx_graph_proto.input(i);
170
171 // Already substituted by ConstGen node
172 if (initializer_name_set.find(input.name()) != initializer_name_set.end())
173 continue;
174
175 auto pull_node = graph->nodes()->create<loco::Pull>();
176
177 auto tensor = input.type().tensor_type();
179 pull_node->rank(tensor.shape().dim_size());
180 for (uint32_t i = 0; i < pull_node->rank(); ++i)
181 {
182 pull_node->dim(i) = (uint32_t)tensor.shape().dim(i).dim_value();
183 }
184
185 nodes->enroll(input.name(), pull_node);
186 }
187
188 // 4. Connect inputs: set all node input(from a string) to actual node object
189 loco::Graph::NodeContext *graph_nodes = graph->nodes();
190 uint32_t nodes_count = graph_nodes->size();
191 for (uint32_t n = 0; n < nodes_count; ++n)
192 {
193 loco::Node *node_to_set = graph_nodes->at(n);
194
195 unsigned int names_size = input_names->size(node_to_set);
196 assert(names_size == node_to_set->arity());
197 for (unsigned int i = 0; i < names_size; ++i)
198 {
199 auto input_name = input_names->name(node_to_set, i);
200 auto node = nodes->node(input_name);
201
202 // TODO use enum instead of dynamic_cast
203 loco::Forward *forward_node = dynamic_cast<loco::Forward *>(node_to_set);
204 if (forward_node != nullptr)
205 forward_node->input(node);
206 }
207 }
208
209 // 5. Set graph input
210 for (int i = 0; i < onnx_graph_proto.input_size(); i++)
211 {
212 auto input = onnx_graph_proto.input(i).name();
213
214 // Already substituted by ConstGen node
215 if (initializer_name_set.find(input) != initializer_name_set.end())
216 continue;
217
218 auto node = nodes->node(input);
219 assert(node != nullptr);
220
221 auto graph_input = graph->inputs()->create();
222
223 loco::Pull *pull_node = dynamic_cast<loco::Pull *>(node);
224 assert(pull_node != nullptr);
225
226 graph_input->name(input);
227 loco::link(graph_input, pull_node);
228 }
229
230 // 6. Create loco::Push node (with a proper input), and mark it as a graph output
231 for (int i = 0; i < onnx_graph_proto.output_size(); i++)
232 {
233 auto output = onnx_graph_proto.output(i).name();
234
235 auto output_node = nodes->node(output);
236 assert(output_node);
237
238 // create loco::Push for output of graph
239 auto push_node = graph->nodes()->create<loco::Push>();
240 push_node->from(output_node); // set input of Push to output node
241
242 // set the graph output name and node object
243 auto graph_output = graph->outputs()->create();
244 graph_output->name(output);
245 loco::link(graph_output, push_node);
246 }
247}
248
249} // namespace
250
251namespace moco
252{
253namespace onnx
254{
255
257{
258 // DO NOTHING
259}
260
261std::unique_ptr<loco::Graph> Frontend::load(const char *modelfile, FileType type) const
262{
263 ::onnx::ModelProto onnx_model_proto;
264
265 load_onnx(modelfile, type, onnx_model_proto);
266
267 auto graph = loco::make_graph();
268
269 convert_graph(onnx_model_proto, graph.get());
270
271 return std::move(graph);
272}
273
274} // namespace onnx
275} // namespace moco
enco::Bundle load(void) const override
Definition Frontend.cpp:40
POSIX File Descriptor.
Definition Fildes.h:29
int get(void) const
Definition Fildes.cpp:74
Create a value from constant byte array.
Definition Nodes.h:218
uint32_t size(void) const
Return the number of reserved elements.
Definition Nodes.cpp:185
Create a new value identical to its input.
Definition Nodes.h:146
Node * input(void) const
Definition Nodes.h:151
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
virtual uint32_t arity(void) const =0
Return the number of arguments.
T * at(uint32_t n) const
Access N-th object.
Definition ObjectPool.h:41
uint32_t size(void) const
Return the number of objects.
Definition ObjectPool.h:38
Create a value from user data.
Definition Nodes.h:96
void dtype(const DataType &d)
Definition Nodes.cpp:129
Make a value visible to user.
Definition Nodes.h:53
Node * from(void) const
Definition Nodes.h:58
Class to store context to build IR from onnx.
static GraphBuilderRegistry & get()
result
Definition infer.py:103
type
Definition infer.py:18
void link(GraphOutput *, Push *push)
Definition Nodes.cpp:65
Pull * pull_node(Graph *g, const GraphInputIndex &index)
Find a Pull node with a given input index.
Definition Nodes.cpp:162
std::unique_ptr< Graph > make_graph(void)
Definition Graph.cpp:131
Push * push_node(Graph *g, const GraphOutputIndex &index)
Find a Push node with a given output index.
Definition Nodes.cpp:67
CircleOutput * output_node(loco::Graph *g, const loco::GraphOutputIndex &index)
Find a CircleOutput node with a given output index.
std::vector< float > get_float_data(const ::onnx::TensorProto &tensor)
Get float tensor data.
Definition Onnxutil.cpp:48
loco::DataType as_loco_datatype(const int32_t tensor_dtype)
Definition Convert.cpp:28
bool is_default_domain(const std::string domain)
If domain is empty string or onnx.ai, it is default domain.
Definition Onnxutil.cpp:43
Definition Log.h:23