ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Importer.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "moco/Importer.h"
20
22
24#include <moco/IR/TFNode.h>
25
26#include <oops/UserExn.h>
27
28#include <memory>
29#include <cassert>
30#include <sstream>
31#include <stdexcept>
32
33namespace
34{
35
36void convert_graph(const moco::GraphBuilderSource &source, const moco::ModelSignature &signature,
37 tensorflow::GraphDef &tf_graph_def, loco::Graph *graph)
38{
39 auto nodedef = std::make_unique<moco::NodeDefTable>();
40 auto tensor_names = std::make_unique<moco::SymbolTable>();
41 auto updates = std::make_unique<moco::UpdateQueue>();
42
43 moco::GraphBuilderContext gb_context(graph, nodedef.get(), tensor_names.get(), updates.get());
44
45 // Building a loco graph
46 // 1. Convert all the nodes to moco::TFNode
47 // 2. Connect inputs: set all node input(from a string) to actual node object
48 // 3. Set graph input
49 // 4. Create moco::TFPush node and set graph output
50
54 for (const auto &n : tf_graph_def.node())
55 {
56 nodedef->enroll(n.name(), &n);
57 }
58
86 for (const auto &n : tf_graph_def.node())
87 {
88 if (const auto *graph_builder = source.lookup(n.op()))
89 {
90 if (!graph_builder->validate(n))
91 {
92 throw oops::UserExn("Invalid operator", n.op());
93 }
94
95 graph_builder->build(n, &gb_context);
96 }
97 else
98 {
99 throw oops::UserExn("Not supported", n.op());
100 }
101 }
102
114 for (auto &update : updates->queue())
115 {
116 update->input(tensor_names.get());
117 }
118
122 for (auto input : signature.inputs())
123 {
124 auto node = tensor_names->node(input);
125 assert(node != nullptr);
126
127 auto graph_input = graph->inputs()->create();
128
129 auto placeholder_node = loco::must_cast<moco::TFPlaceholder *>(node);
130
131 graph_input->name(input.nodeName());
132
133 // annotate index that should be passed to loco::Pull
134 moco::index(placeholder_node, graph_input->index());
135
136 // This implementation works as "PlaceholderGraphBuilder in Nodes/Placeholder.cpp"
137 // accepts only TF_FLOAT32 as of now.
138 //
139 // TODO Support other types
140 graph_input->dtype(loco::DataType::FLOAT32);
141 }
142
146 for (auto output : signature.outputs())
147 {
148 auto output_node = tensor_names->node(output);
149 assert(output_node);
150
151 // create moco::TFPush for output of graph
152 auto push_node = graph->nodes()->create<moco::TFPush>();
153 push_node->from(output_node); // set input of TFPush to output node
154
155 // set the graph output name and node object
156 auto graph_output = graph->outputs()->create();
157 graph_output->name(output.nodeName());
158 push_node->index(graph_output->index());
159
160 // TODO Support other types
161 graph_output->dtype(loco::DataType::FLOAT32);
162 }
163
164 // validate graph
165 assert(loco::valid(graph));
166}
167
168} // namespace
169
170namespace moco
171{
172
// NOTE(review): the signature line of this definition (original line 173) is absent from this
// listing; only its empty body survives. Presumably this is the Importer constructor — confirm.
174{
175 // DO NOTHING
176}
177
/**
 * @brief Convert a TensorFlow GraphDef into a newly created loco::Graph
 *
 * @param signature     model inputs/outputs forwarded to convert_graph
 * @param tf_graph_def  TensorFlow graph to convert
 * @return the newly created graph
 *
 * When _source is non-null it is used as the GraphBuilderSource; otherwise source_ptr keeps
 * its initial value (see the note below).
 * Any exception thrown by convert_graph (e.g. oops::UserExn for unsupported or invalid
 * operators) propagates to the caller.
 */
178std::unique_ptr<loco::Graph> Importer::import(const ModelSignature &signature,
179 tensorflow::GraphDef &tf_graph_def) const
180{
181 auto graph = loco::make_graph();
182
// NOTE(review): the declaration/initialization of source_ptr (original line 183) is absent from
// this listing; presumably it defaults to the built-in GraphBuilderRegistry::get() — confirm.
184
185 if (_source != nullptr)
186 {
187 // Use user-defined GraphBuilderSource
188 source_ptr = _source;
189 }
190
191 convert_graph(*source_ptr, signature, tf_graph_def, graph.get());
192
193 return graph;
194}
195
196} // namespace moco
A neural network graph.
Definition Graph.h:161
Node * from(void) const
Definition Nodes.h:58
void index(const GraphOutputIndex &index)
Definition Nodes.cpp:52
Class to store context to build loco graph IR from TensorFlow.
static GraphBuilderRegistry & get()
std::unique_ptr< loco::Graph > import(const ModelSignature &, tensorflow::GraphDef &) const
Definition Importer.cpp:178
Make a value visible to user.
Definition TFPush.h:42
Exception to user.
Definition UserExn.h:42
bool valid(Graph *g, std::unique_ptr< ErrorListener > &&l=nullptr)
Validate a loco graph.
Definition Verifier.cpp:100
std::unique_ptr< Graph > make_graph(void)
Definition Graph.cpp:131
Push * push_node(Graph *g, const GraphOutputIndex &index)
Find a Push node with a given output index.
Definition Nodes.cpp:67
CircleOutput * output_node(loco::Graph *g, const loco::GraphOutputIndex &index)
Find a CircleOutput node with a given output index.
Definition Log.h:23
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
TFPlaceholder * placeholder_node(loco::Graph *g, const loco::GraphInputIndex &idx)
Definition TFNode.cpp:84
FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
virtual const GraphBuilder * lookup(const std::string &op) const =0
Returns registered GraphBuilder pointer for operator (nullptr if not present)
Class to store information to run a model. Normally this info comes from users via CLI params or conf...