ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Graph.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "loco/IR/Graph.h"
18
19#include <memory>
20#include <cassert>
21
22namespace
23{
24
25std::unique_ptr<loco::TensorShape> make_tensor_shape(std::initializer_list<loco::Dimension> dims)
26{
27 auto tensor_shape = std::make_unique<loco::TensorShape>();
28
29 tensor_shape->rank(dims.size());
30 {
31 uint32_t axis = 0;
32 for (auto it = dims.begin(); it != dims.end(); ++it)
33 {
34 tensor_shape->dim(axis++) = *it;
35 }
36 assert(axis == dims.size());
37 }
38
39 return tensor_shape;
40}
41
42} // namespace
43
44namespace loco
45{
46
47void Mixin<Trait::TensorShaped>::shape(std::initializer_list<Dimension> dims)
48{
49 shape(make_tensor_shape(dims));
50}
51
52GraphInput *Graph::InputContext::create(void) { return take(std::make_unique<GraphInput>(size())); }
53
55{
56 return take(std::make_unique<GraphOutput>(size()));
57}
58
59std::set<loco::Node *> all_nodes(loco::Graph *g)
60{
61 std::set<loco::Node *> res;
62
63 for (uint32_t n = 0; n < g->nodes()->size(); ++n)
64 {
65 res.insert(g->nodes()->at(n));
66 }
67
68 return res;
69}
70
71std::vector<Node *> input_nodes(const Graph *g)
72{
73 std::map<GraphInputIndex, loco::Node *> table;
74
75 for (uint32_t n = 0; n < g->nodes()->size(); ++n)
76 {
77 auto node = g->nodes()->at(n);
78
79 if (auto service = node->dialect()->service<GraphInputIndexQueryService>())
80 {
81 if (service->associated(node))
82 {
83 auto input_index = service->index(node);
84 assert(table.find(input_index) == table.end());
85 table[input_index] = node;
86 }
87 }
88 }
89
90 std::vector<loco::Node *> res;
91
92 for (uint32_t n = 0; n < g->inputs()->size(); ++n)
93 {
94 auto it = table.find(n);
95 res.emplace_back(it == table.end() ? nullptr : it->second);
96 }
97
98 return res;
99}
100
101std::vector<loco::Node *> output_nodes(loco::Graph *g)
102{
103 std::map<GraphOutputIndex, loco::Node *> table;
104
105 for (uint32_t n = 0; n < g->nodes()->size(); ++n)
106 {
107 auto node = g->nodes()->at(n);
108
109 if (auto service = node->dialect()->service<GraphOutputIndexQueryService>())
110 {
111 if (service->associated(node))
112 {
113 auto output_index = service->index(node);
114 assert(table.find(output_index) == table.end());
115 table[output_index] = node;
116 }
117 }
118 }
119
120 std::vector<loco::Node *> res;
121
122 for (uint32_t n = 0; n < g->outputs()->size(); ++n)
123 {
124 auto it = table.find(n);
125 res.emplace_back(it == table.end() ? nullptr : it->second);
126 }
127
128 return res;
129}
130
131std::unique_ptr<Graph> make_graph(void) { return std::unique_ptr<Graph>{new Graph}; }
132
133} // namespace loco
A neural network graph.
Definition Graph.h:161
Graph-level Input Metadata.
Definition Graph.h:107
Graph-level Output Metadata.
Definition Graph.h:135
U * take(std::unique_ptr< U > &&o)
Takes ownership of a given object and returns its raw pointer.
Definition ObjectPool.h:45
uint32_t size(void) const
Return the number of objects.
Definition ObjectPool.h:38
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
uint32_t rank(void) const
Definition TensorShape.h:35
std::set< Node * > all_nodes(Graph *)
Enumerate all the nodes in a given graph.
Definition Graph.cpp:59
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
std::unique_ptr< Graph > make_graph(void)
Definition Graph.cpp:131
int32_t size[5]
Definition Slice.cpp:35
GraphInput * create(void)
Definition Graph.cpp:52
GraphOutput * create(void)
Definition Graph.cpp:54