ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Graph.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __LOCO_IR_GRAPH_H__
18#define __LOCO_IR_GRAPH_H__
19
20#include "loco/IR/DataType.h"
21// TODO Include "Node.h" instead
22#include "loco/IR/Nodes.h"
23#include "loco/IR/NodePool.h"
26
27#include "loco/ADT/ObjectPool.h"
28
29#include <initializer_list>
30#include <set>
31#include <string>
32#include <memory>
33#include <vector>
34
35namespace loco
36{
37
// TODO Introduce Named trait
/**
 * @brief Capabilities that a class can acquire by inheriting Mixin<T>
 */
enum class Trait
{
  // Any "DataTyped" class has the following methods
  // - const DataType &dtype(void) const;
  // - void dtype(const DataType &value);
  DataTyped,
  // Any "TensorShaped" class has the following methods
  // - const TensorShape *shape(void) const;
  // - void shape(std::unique_ptr<TensorShape> &&);
  // - void shape(std::initializer_list<Dimension>);
  //
  // TODO Rename NodeMixin::TensorShape as NodeMixin::NDShape
  TensorShaped,
};
53
54template <Trait T> class Mixin;
55
56// TODO Re-implement NodeMixin<NodeTrait::DataType> using this mixin
57template <> class Mixin<Trait::DataTyped>
58{
59public:
60 Mixin() = default;
61
62public:
63 const DataType &dtype(void) const { return _dtype; }
64 void dtype(const DataType &value) { _dtype = value; }
65
66private:
67 DataType _dtype = DataType::Unknown;
68};
69
70template <> class Mixin<Trait::TensorShaped>
71{
72public:
73 Mixin() = default;
74
75public:
76 const TensorShape *shape(void) const { return _shape.get(); }
77 void shape(std::unique_ptr<TensorShape> &&shape) { _shape = std::move(shape); }
78 void shape(std::initializer_list<Dimension> dims);
79
80private:
81 std::unique_ptr<TensorShape> _shape = nullptr;
82};
83
/**
 * @brief Trait for elements with name
 */
class NamedEntity
{
public:
  /// @brief Returns the current name (an empty string until one is assigned)
  const std::string &name(void) const { return _name; }
  /// @brief Replaces the current name with a copy of @p name
  void name(const std::string &name) { _name = name; }

// A class that inherits NamedEntity privately can place this macro in one of its
// public sections to re-expose the name() accessors. If more accessors are added
// to NamedEntity, they must be listed in this macro as well.
#define LOCO_NAMED_ENTITY_EXPOSE using NamedEntity::name

private:
  std::string _name;
};
100
104class GraphInput final : private NamedEntity,
105 public Mixin<Trait::DataTyped>,
106 public Mixin<Trait::TensorShaped>
107{
108public:
110
111 // TODO Use GraphInputIndex (instead of uint32_t)
112 GraphInput(uint32_t index) : _index{index}
113 {
114 // DO NOTHING
115 }
116
117 GraphInput(const GraphInput &) = delete;
118 GraphInput(GraphInput &&) = delete;
119
120 ~GraphInput() = default;
121
122public:
123 GraphInputIndex index(void) const { return _index; }
124
125private:
126 uint32_t _index;
127};
128
132class GraphOutput final : private NamedEntity,
133 public Mixin<Trait::DataTyped>,
134 public Mixin<Trait::TensorShaped>
135{
136public:
138
139 // TODO Use GraphOutputIndex (instead of uint32_t)
140 GraphOutput(uint32_t index) : _index{index}
141 {
142 // DO NOTHING
143 }
144
145 GraphOutput(const GraphOutput &) = delete;
147
148 ~GraphOutput() = default;
149
150public:
151 GraphOutputIndex index(void) const { return _index; }
152
153private:
154 uint32_t _index;
155};
156
160class Graph final : public NamedEntity
161{
162public:
171
177 template <typename T> struct SimpleFactoryObjectPool : public ObjectPool<T>
178 {
179 virtual ~SimpleFactoryObjectPool() = default;
180
181 T *create(void)
182 {
183 std::unique_ptr<T> ptr{new T};
184 return ObjectPool<T>::take(std::move(ptr));
185 }
186 };
187
191 struct InputContext final : public ObjectPool<GraphInput>
192 {
193 GraphInput *create(void);
194 };
195
199 struct OutputContext final : public ObjectPool<GraphOutput>
200 {
201 GraphOutput *create(void);
202 };
203
204public:
206 {
207 // Associate "NodeContext" and the current "Graph"
208 _node_ctx.graph(this);
209 }
210
211 // Copy/Move is not allowed for Graph
212 Graph(const Graph &) = delete;
213 Graph(Graph &&) = delete;
214
215 ~Graph() = default;
216
217public:
218 NodeContext *nodes(void) { return &_node_ctx; }
219 const NodeContext *nodes(void) const { return &_node_ctx; }
220 InputContext *inputs(void) { return &_input_ctx; }
221 const InputContext *inputs(void) const { return &_input_ctx; }
222 OutputContext *outputs(void) { return &_output_ctx; }
223 const OutputContext *outputs(void) const { return &_output_ctx; }
224
225private:
226 NodeContext _node_ctx;
227 InputContext _input_ctx;
228 OutputContext _output_ctx;
229};
230
232{
233 virtual ~GraphInputIndexQueryService() = default;
234
238 virtual bool associated(const Node *node) const = 0;
239
245 virtual GraphInputIndex index(const Node *node) const = 0;
246};
247
248std::vector<Node *> input_nodes(const Graph *);
249
251{
252 virtual ~GraphOutputIndexQueryService() = default;
253
257 virtual bool associated(const Node *node) const = 0;
258
264 virtual GraphOutputIndex index(const Node *node) const = 0;
265};
266
267std::vector<Node *> output_nodes(Graph *);
268
277std::set<Node *> all_nodes(Graph *);
278
279std::unique_ptr<Graph> make_graph(void);
280
281} // namespace loco
282
283#endif // __LOCO_IR_GRAPH_H__
A neural network graph.
Definition Graph.h:161
InputContext * inputs(void)
Definition Graph.h:220
NodeContext * nodes(void)
Definition Graph.h:218
const InputContext * inputs(void) const
Definition Graph.h:221
const NodeContext * nodes(void) const
Definition Graph.h:219
NodePool NodeContext
Node Pool.
Definition Graph.h:170
Graph(Graph &&)=delete
Graph(const Graph &)=delete
const OutputContext * outputs(void) const
Definition Graph.h:223
OutputContext * outputs(void)
Definition Graph.h:222
~Graph()=default
Graph-level Input Metadata.
Definition Graph.h:107
GraphInput(uint32_t index)
Definition Graph.h:112
GraphInput(const GraphInput &)=delete
GraphInputIndex index(void) const
Definition Graph.h:123
GraphInput(GraphInput &&)=delete
~GraphInput()=default
Graph-level Output Metadata.
Definition Graph.h:135
~GraphOutput()=default
GraphOutputIndex index(void) const
Definition Graph.h:151
GraphOutput(GraphOutput &&)=delete
GraphOutput(const GraphOutput &)=delete
GraphOutput(uint32_t index)
Definition Graph.h:140
void dtype(const DataType &value)
Definition Graph.h:64
const DataType & dtype(void) const
Definition Graph.h:63
const TensorShape * shape(void) const
Definition Graph.h:76
void shape(std::unique_ptr< TensorShape > &&shape)
Definition Graph.h:77
Trait for elements with name.
Definition Graph.h:88
void name(const std::string &name)
Definition Graph.h:91
const std::string & name(void) const
Definition Graph.h:90
Logical unit of computation.
Definition Node.h:54
Object Pool.
Definition ObjectPool.h:32
U * take(std::unique_ptr< U > &&o)
Take the ownership of a given object and returns its raw pointer.
Definition ObjectPool.h:45
uint32_t GraphInputIndex
std::set< Node * > all_nodes(Graph *)
Enumerate all the nodes in a given graph.
Definition Graph.cpp:59
uint32_t GraphOutputIndex
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
DataType
"scalar" value type
Definition DataType.h:27
Trait
Definition Graph.h:40
std::unique_ptr< Graph > make_graph(void)
Definition Graph.cpp:131
Dialect Service interface.
GraphInput Pool.
Definition Graph.h:192
GraphInput * create(void)
Definition Graph.cpp:52
GraphOutput Pool.
Definition Graph.h:200
GraphOutput * create(void)
Definition Graph.cpp:54
Object Pool with Simple Factory Method.
Definition Graph.h:178
virtual ~SimpleFactoryObjectPool()=default
virtual GraphInputIndex index(const Node *node) const =0
virtual bool associated(const Node *node) const =0
Check whether a given node is associated with any Graph-level input.
virtual ~GraphInputIndexQueryService()=default
virtual GraphOutputIndex index(const Node *node) const =0
virtual ~GraphOutputIndexQueryService()=default
virtual bool associated(const Node *node) const =0
Check whether a given node is associated with any Graph-level output.