ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TestGraph.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __TEST_GRAPH_H__
18#define __TEST_GRAPH_H__
19
20#include <luci/IR/CircleNodes.h>
21
22#include <loco.h>
23
24#include <cassert>
25#include <memory>
26
27// TODO Change all Canonical nodes to Circle nodes
28
29namespace luci
30{
31namespace test
32{
33
// NOTE(review): this is the body of the test helper class `TestGraph` (name taken from the
// constructor below); its header on listing line 34 was dropped from this extraction.
// TestGraph owns a loco::Graph (`g`) holding one CircleInput and one CircleOutput node.
// The append<T>() overloads create a node of type T, wire its inputs through the private
// setInput() overloads and remember it as the most recently appended node; complete()
// then connects the CircleOutput node to that node (or to an explicitly given node).
35{
36public:
37 std::unique_ptr<loco::Graph> g;
 // NOTE(review): listing lines 38-39 are missing here; per the tooltip index they declare
 // the public members `luci::CircleInput *input_node` and `luci::CircleOutput *output_node`.
40
41 TestGraph() // creates a CircleInput and a CircleOutput node and links them to graph input "input" / output "output"
42 {
 // NOTE(review): listing line 43 is missing; `g` must already own a graph before the
 // dereferences below -- presumably `g = loco::make_graph();` (make_graph appears in the
 // tooltip index). Confirm against the original header.
44
45 input_node = g->nodes()->create<luci::CircleInput>();
46
47 output_node = g->nodes()->create<luci::CircleOutput>();
48
49 auto input = g->inputs()->create();
50 {
51 input->name("input");
52 luci::link(input, input_node);
53 }
54 auto output = g->outputs()->create();
55 {
56 output->name("output");
57 luci::link(output, output_node);
58 }
59
 // Until a node is appended, complete() connects the output directly to the input.
60 _next_input = input_node;
61 }
62
 /// @brief Returns the graph under construction (non-owning; `g` keeps ownership)
63 loco::Graph *graph() { return g.get(); }
64
 /// @brief Creates node with NO arg and appends it to graph
 /// @return the new node, which also becomes the node complete() connects to
66 template <class T> T *append()
67 {
68 auto node = g->nodes()->create<T>();
69 _next_input = node;
70
71 return node;
72 }
73
 /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph
 /// @note Inputs are wired by the setInput() overload matching T; a type without one
 ///       falls back to the "NYI" assert overload.
75 template <class T> T *append(luci::CircleNode *arg1)
76 {
77 auto node = g->nodes()->create<T>();
78 setInput(node, arg1);
79 _next_input = node;
80
81 return node;
82 }
83
 /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph
85 template <class T> T *append(luci::CircleNode *arg1, luci::CircleNode *arg2)
86 {
87 auto node = g->nodes()->create<T>();
88 setInput(node, arg1, arg2);
89 _next_input = node;
90
91 return node;
92 }
93
 /// @brief Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph
 // NOTE(review): the signature line (listing line 96; per the tooltip index
 // `T *append(luci::CircleNode *arg1, luci::CircleNode *arg2, luci::CircleNode *arg3)`)
 // is missing between the two lines below.
95 template <class T>
97 {
98 auto node = g->nodes()->create<T>();
99 setInput(node, arg1, arg2, arg3);
100 _next_input = node;
101
102 return node;
103 }
104
105 // Connects output_node to the most recently appended node (input_node if nothing was appended)
106 void complete() { output_node->from(_next_input); }
107
 // Connects output_node to `last_node`, regardless of which node was appended last.
108 void complete(luci::CircleNode *last_node) { output_node->from(last_node); }
109
110private:
 // setInput() overloads: overload resolution on the concrete node type picks the wiring.
 // The luci::CircleNode* fallbacks only assert "NYI"; when NDEBUG is defined the assert
 // compiles away and the new node is silently left without inputs.
111 // arity 1
112 void setInput(luci::CircleNode *, luci::CircleNode *) { assert(false && "NYI"); };
113
114 void setInput(luci::CircleAveragePool2D *node, luci::CircleNode *input) { node->value(input); };
115 void setInput(luci::CircleRelu *node, luci::CircleNode *input) { node->features(input); };
116 void setInput(luci::CircleSqueeze *node, luci::CircleNode *input) { node->input(input); };
117
 // NOTE(review): this GatherNd overload takes two inputs (params, indices), so it belongs
 // with the arity-2 overloads even though it sits under the "arity 1" heading.
118 void setInput(luci::CircleGatherNd *node, luci::CircleNode *params, luci::CircleNode *indices)
119 {
120 node->params(params);
121 node->indices(indices);
122 };
123
124 // arity 2
 // NOTE(review): this overload's signature (listing line 125) is missing; by analogy with
 // the arity-1 fallback it is presumably the generic
 // setInput(luci::CircleNode *, luci::CircleNode *, luci::CircleNode *) -- confirm.
126 {
127 assert(false && "NYI");
128 };
129
130 void setInput(luci::CircleExpandDims *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
131 {
132 node->input(arg1);
133 node->axis(arg2);
134 };
135
136 void setInput(luci::CircleTranspose *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
137 {
138 node->a(arg1);
139 node->perm(arg2);
140 };
141
 // NOTE(review): the signature (listing line 142) is missing; the body sets input() and
 // size(), matching the RESIZE_BILINEAR accessors in the tooltip index -- presumably the
 // luci::CircleResizeBilinear overload. Confirm.
143 {
144 node->input(input);
145 node->size(size);
146 };
147
 // NOTE(review): listing line 149, the rest of the parameter list below (presumably
 // `luci::CircleNode *size)`), is missing.
148 void setInput(luci::CircleResizeNearestNeighbor *node, luci::CircleNode *input,
150 {
151 node->input(input);
152 node->size(size);
153 };
154
155 // arity 3
 // NOTE(review): the signature (listing line 156) is missing; presumably the generic
 // four-CircleNode* fallback that append(arg1, arg2, arg3) resolves to -- confirm.
157 {
158 assert(false && "NYI");
159 };
160
161private:
 // Most recently appended node (initially input_node); read by complete().
162 loco::Node *_next_input;
163};
164
// NOTE(review): listing lines 165 and 167 were dropped in extraction. Line 170 below uses
// `ExampleGraphType` as a non-type template parameter type, so this braced body is
// presumably `enum class ExampleGraphType` (enumerator presumably CircleTranspose) -- confirm.
166{
168};
169
// The primary template is declared but never defined, so only explicitly specialized
// ExampleGraphType values can be instantiated.
170template <ExampleGraphType T> class ExampleGraph;
171
// NOTE(review): listing lines 172-177 (the class header and its comment) are missing. The
// body uses append()/complete()/input_node, so it is presumably the ExampleGraph
// specialization for the Transpose example graph, deriving from TestGraph -- confirm.
178{
179public:
 // Node supplying the Transpose `perm` input; created with no data set here.
180 luci::CircleConst *const_perm = nullptr;
181 luci::CircleTranspose *transpose_node = nullptr;
182
183public:
 // NOTE(review): the constructor signature (listing line 184) is missing.
 // Builds: input_node -> CircleTranspose(a = input_node, perm = const_perm) -> output_node.
185 {
186 const_perm = append<luci::CircleConst>();
187 transpose_node = append<luci::CircleTranspose>(input_node, const_perm);
188 complete(transpose_node);
189 }
190};
191
192} // namespace test
193} // namespace luci
194
195namespace luci
196{
197namespace test
198{
199
// NOTE(review): the declarations on listing lines 200-201, 203-204, 206-207 and 209-210
// were dropped in extraction. Per the tooltip index they are presumably:
//   void graph_input_shape(luci::CircleInput *input);    // GraphInput shape from CircleInput
//   void graph_output_shape(luci::CircleOutput *output); // GraphOutput shape from CircleOutput
//   void graph_input_dtype(luci::CircleInput *input);    // GraphInput dtype from CircleInput
//   void graph_output_dtype(luci::CircleOutput *output); // GraphOutput dtype from CircleOutput
// Confirm their order against the original header.
202
205
208
211
212} // namespace test
213} // namespace luci
214
215#endif // __TEST_GRAPH_H__
A neural network graph.
Definition Graph.h:161
Logical unit of computation.
Definition Node.h:54
AVERAGE_POOL_2D in Circle.
loco::Node * value(void) const
Class to build tensor data.
Definition CircleConst.h:35
EXPAND_DIMS in Circle.
loco::Node * axis(void) const
loco::Node * input(void) const
GATHER_ND in Circle.
loco::Node * params(void) const
loco::Node * indices(void) const
CircleNode used for Input of the Graph.
Definition CircleInput.h:36
CircleNode for Output of the Graph.
loco::Node * from(void) const
RELU in Circle.
Definition CircleRelu.h:32
loco::Node * features(void) const
Definition CircleRelu.h:34
RESIZE_BILINEAR in Circle.
loco::Node * size(void) const
loco::Node * input(void) const
RESIZE_NEAREST_NEIGHBOR in Circle.
SQUEEZE in Circle.
loco::Node * input(void) const
TRANSPOSE in Circle.
loco::Node * a(void) const
loco::Node * perm(void) const
T * append(luci::CircleNode *arg1)
Creates op T (arity=1) with arg1 as an input and appends it to graph.
Definition TestGraph.h:75
loco::Graph * graph()
Definition TestGraph.h:63
T * append(luci::CircleNode *arg1, luci::CircleNode *arg2, luci::CircleNode *arg3)
Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph.
Definition TestGraph.h:96
luci::CircleOutput * output_node
Definition TestGraph.h:39
luci::CircleInput * input_node
Definition TestGraph.h:38
T * append()
Creates a node with no arguments and appends it to the graph.
Definition TestGraph.h:66
T * append(luci::CircleNode *arg1, luci::CircleNode *arg2)
Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph.
Definition TestGraph.h:85
std::unique_ptr< loco::Graph > g
Definition TestGraph.h:37
void complete(luci::CircleNode *last_node)
Definition TestGraph.h:108
std::unique_ptr< Graph > make_graph(void)
Definition Graph.cpp:131
void graph_output_shape(luci::CircleOutput *output)
This will set GraphOutput shape from CircleOutput shape.
void graph_input_shape(luci::CircleInput *input)
This will set GraphInput shape from CircleInput shape.
void graph_input_dtype(luci::CircleInput *input)
This will set GraphInput dtype from CircleInput dtype.
void graph_output_dtype(luci::CircleOutput *output)
This will set GraphOutput dtype from CircleOutput dtype.
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Finds the CircleInput node with a given graph input index.
void link(loco::GraphOutput *, CircleOutput *)
Link GraphOutput with CircleOutput node.
int32_t size[5]
Definition Slice.cpp:35