ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FormattedGraph.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
20
23
24#include <pp/Format.h>
25
26#include <memory>
27#include <map>
28#include <set>
29
30#include <cassert>
31
33
34namespace
35{
36
37std::string str(const loco::DataType &dtype)
38{
39 switch (dtype)
40 {
41 case loco::DataType::Unknown:
42 return "Unknown";
43
44 case loco::DataType::U4:
45 return "U4";
46 case loco::DataType::U8:
47 return "U8";
48 case loco::DataType::U16:
49 return "U16";
50 case loco::DataType::U32:
51 return "U32";
52 case loco::DataType::U64:
53 return "U64";
54
55 case loco::DataType::S4:
56 return "S4";
57 case loco::DataType::S8:
58 return "S8";
59 case loco::DataType::S16:
60 return "S16";
61 case loco::DataType::S32:
62 return "S32";
63 case loco::DataType::S64:
64 return "S64";
65
66 case loco::DataType::FLOAT16:
67 return "FLOAT16";
68 case loco::DataType::FLOAT32:
69 return "FLOAT32";
70 case loco::DataType::FLOAT64:
71 return "FLOAT64";
72
73 case loco::DataType::BOOL:
74 return "BOOL";
75
76 default:
77 break;
78 };
79
80 throw std::invalid_argument{"dtype"};
81}
82
83std::string str(const loco::Domain &domain)
84{
85 // TODO Generate!
86 switch (domain)
87 {
89 return "Unknown";
91 return "Tensor";
93 return "Feature";
95 return "Filter";
97 return "DWFilter";
99 return "Bias";
100 default:
101 break;
102 }
103
104 throw std::invalid_argument{"domain"};
105}
106
107std::string str(const loco::NodeShape &node_shape)
108{
109 using namespace locop;
110
111 switch (node_shape.domain())
112 {
114 {
115 auto tensor_shape = node_shape.as<loco::TensorShape>();
116 return pp::fmt(locop::fmt<TensorShapeFormat::Plain>(&tensor_shape));
117 }
118 // TODO Show details
123 return "...";
124
125 default:
126 break;
127 }
128
129 throw std::invalid_argument{"domain"};
130}
131
132// TODO Use locop::fmt<TensorShapeFormat ...>
134formatted_tensor_shape(const loco::TensorShape *ptr)
135{
137}
138
139} // namespace
140
141namespace
142{
143
144struct NodeDesc : public locop::NodeDesc
145{
146public:
147 NodeDesc() = default;
148 NodeDesc(const locop::OpName &opname) : locop::NodeDesc{opname}
149 {
150 // DO NOTHING
151 }
152
153public:
154 // DEPRECATED
155 const locop::OpName &name(void) const { return opname(); }
156
157 // DEPRECATED
158 uint32_t arg_size(void) const { return args().count(); }
159 // DEPRECATED
160 const locop::ArgElem &arg(uint32_t n) const { return args().at(n); }
161 // DEPRECATED
162 void arg(const locop::ArgName &name, const locop::ArgValue &value) { args().append(name, value); }
163};
164
165} // namespace
166
167// TODO Remove this workaround
168namespace locop
169{
170
171std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
172{
173 assert(d.state() != NodeDesc::State::Invalid);
174
175 std::vector<std::string> values;
176
177 for (uint32_t n = 0; n < d.args().count(); ++n)
178 {
179 values.emplace_back(d.args().at(n).first + ": " + d.args().at(n).second);
180 }
181
182 if (d.state() == NodeDesc::State::PartiallyKnown)
183 {
184 values.emplace_back("...");
185 }
186
187 os << d.opname();
188 os << "(";
189 if (values.size() > 0)
190 {
191 os << values.at(0);
192 for (uint32_t n = 1; n < values.size(); ++n)
193 {
194 os << ", " << values.at(n);
195 }
196 }
197 os << ")";
198
199 return os;
200}
201
202} // namespace locop
203
204namespace locop
205{
206
207std::ostream &operator<<(std::ostream &os, const FormattedGraph &fmt)
208{
209 fmt.dump(os);
210 return os;
211}
212
213} // namespace locop
214
215namespace locop
216{
217
219{
220 struct SymbolTableImpl final : public SymbolTable
221 {
222 std::string lookup(const loco::Node *node) const final
223 {
224 if (node == nullptr)
225 {
226 return "(null)";
227 }
228
229 return _content.at(node);
230 }
231
232 std::map<const loco::Node *, std::string> _content;
233 };
234
235 SymbolTableImpl symbols;
236
237 auto symbol = [&symbols](const loco::Node *node) { return symbols.lookup(node); };
238
239 for (uint32_t n = 0; n < _graph->nodes()->size(); ++n)
240 {
241 symbols._content[_graph->nodes()->at(n)] = pp::fmt("%", n);
242 }
243
244 // Find the disjoint node clusters
245 //
246 // TODO Move this implementation into loco Algorithms.h
247 std::map<loco::Node *, loco::Node *> parents;
248
249 for (auto node : loco::all_nodes(_graph))
250 {
251 parents[node] = nullptr;
252 }
253
254 for (auto node : loco::all_nodes(_graph))
255 {
256 for (uint32_t n = 0; n < node->arity(); ++n)
257 {
258 if (auto arg = node->arg(n))
259 {
260 parents[arg] = node;
261 }
262 }
263 }
264
265 auto find = [&parents](loco::Node *node) {
266 loco::Node *cur = node;
267
268 while (parents.at(cur) != nullptr)
269 {
270 cur = parents.at(cur);
271 }
272
273 return cur;
274 };
275
276 std::set<loco::Node *> roots;
277
278 for (auto node : loco::all_nodes(_graph))
279 {
280 roots.insert(find(node));
281 }
282
283 std::map<loco::Node *, std::set<loco::Node *>> clusters;
284
285 // Create clusters
286 for (auto root : roots)
287 {
288 clusters[root] = std::set<loco::Node *>{};
289 }
290
291 for (auto node : loco::all_nodes(_graph))
292 {
293 clusters.at(find(node)).insert(node);
294 }
295
296 std::unique_ptr<locop::NodeSummaryBuilder> node_summary_builder;
297
298 if (_factory)
299 {
300 // Use User-defined NodeSummaryBuilder if NodeSummaryBuilderFactory is present
301 node_summary_builder = _factory->create(&symbols);
302 }
303 else
304 {
305 // Use Built-in NodeSummaryBuilder otherwise
306 node_summary_builder = std::make_unique<GenericNodeSummaryBuilder>(&symbols);
307 }
308
309 // Print Graph Input(s)
310 for (uint32_t n = 0; n < _graph->inputs()->size(); ++n)
311 {
312 auto input = _graph->inputs()->at(n);
313
314 std::string name = input->name();
315
316 std::string shape = "?";
317 if (input->shape() != nullptr)
318 {
319 shape = pp::fmt(formatted_tensor_shape(input->shape()));
320 }
321
322 // TODO Print dtype
323 os << pp::fmt("In #", n, " { name: ", name, ", shape: ", shape, " }") << std::endl;
324 }
325
326 // Print Graph Output(s)
327 for (uint32_t n = 0; n < _graph->outputs()->size(); ++n)
328 {
329 auto output = _graph->outputs()->at(n);
330
331 std::string name = output->name();
332
333 std::string shape = "?";
334 if (output->shape() != nullptr)
335 {
336 shape = pp::fmt(formatted_tensor_shape(output->shape()));
337 }
338
339 // TODO Print dtype
340 os << pp::fmt("Out #", n, " { name: ", name, ", shape: ", shape, " }") << std::endl;
341 }
342
343 if (_graph->inputs()->size() + _graph->outputs()->size() != 0)
344 {
345 os << std::endl;
346 }
347
348 for (auto it = clusters.begin(); it != clusters.end(); ++it)
349 {
350 std::vector<loco::Node *> cluster_outputs;
351
352 for (auto node : it->second)
353 {
354 // NOTE This is inefficient but anyway working :)
355 if (loco::succs(node).empty())
356 {
357 cluster_outputs.emplace_back(node);
358 }
359 }
360
361 for (auto node : loco::postorder_traversal(cluster_outputs))
362 {
363 locop::NodeSummary node_summary;
364
365 // Build a node summary
366 if (!node_summary_builder->build(node, node_summary))
367 {
368 throw std::runtime_error{"Fail to build a node summary"};
369 }
370
371 for (uint32_t n = 0; n < node_summary.comments().count(); ++n)
372 {
373 os << "; " << node_summary.comments().at(n) << std::endl;
374 }
375
376 os << symbol(node);
377
378 if (loco::shape_known(node))
379 {
380 auto node_shape = loco::shape_get(node);
381 os << " : " << str(node_shape.domain());
382 os << "<";
383 os << str(node_shape);
384 os << ", ";
385 // Show DataType
386 os << (loco::dtype_known(node) ? str(loco::dtype_get(node)) : std::string{"?"});
387 os << ">";
388 }
389
390 os << " = " << node_summary << std::endl;
391 }
392 os << std::endl;
393 }
394}
395
396} // namespace locop
Logical unit of computation.
Definition Node.h:54
ShapeType as(void) const
const Domain & domain(void) const
Definition NodeShape.h:48
uint32_t count(void) const
The number of presented arguments.
Definition NodeSummary.h:40
const ArgElem & at(uint32_t n) const
Definition NodeSummary.h:42
void append(const ArgName &name, const ArgValue &value)
Definition NodeSummary.h:43
uint32_t count(void) const
Definition NodeSummary.h:61
const std::string & at(uint32_t n) const
Definition NodeSummary.h:62
str
Definition infer.py:18
std::vector< loco::Node * > postorder_traversal(const std::vector< loco::Node * > &roots)
Generate postorder traversal sequence starting from "roots".
Definition Algorithm.cpp:53
std::set< Node * > all_nodes(Graph *)
Enumerate all the nodes in a given graph.
Definition Graph.cpp:59
std::set< Node * > succs(const Node *node)
Enumerate all the successors of a given node.
Definition Node.cpp:46
bool shape_known(const Node *node)
bool dtype_known(const Node *node)
DataType
"scalar" value type
Definition DataType.h:27
NodeShape shape_get(const Node *node)
DataType dtype_get(const Node *node)
Domain
Describe the kind of (N-dimensional) loco values.
Definition Domain.h:40
FormattedGraphImpl< F > fmt(loco::Graph *g)
std::ostream & operator<<(std::ostream &, const FormattedGraph &)
std::pair< ArgName, ArgValue > ArgElem
Definition NodeSummary.h:31
std::string ArgValue
Definition NodeSummary.h:30
std::string OpName
Definition NodeSummary.h:28
std::string ArgName
Definition NodeSummary.h:29
loco::NodeShape node_shape(const loco::Node *node)
int32_t size[5]
Definition Slice.cpp:35
NodeDesc()=default
const OpName & opname(void) const
const ArgDesc & args(void) const
Definition NodeSummary.h:94
const Comments & comments(void) const
Definition NodeSummary.h:97
Symbol Table Interface.
Definition SymbolTable.h:33