ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tf2nnpkg.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "filesystem.h"
18
19#include <moco/LoggingContext.h>
20#include <moco/tf/Frontend.h>
21#include <exo/LoggingContext.h>
22#include <exo/CircleExporter.h>
23
25
27
29#include <hermes/EnvConfig.h>
30
31#include <memory>
32#include <iostream>
33#include <fstream>
34#include <functional>
35#include <stdexcept>
36#include <string>
37#include <vector>
38
39namespace
40{
41
42std::unique_ptr<loco::Graph> import(const moco::ModelSignature &sig, const std::string &path)
43{
44 moco::tf::Frontend frontend;
45 return frontend.load(sig, path.c_str(), moco::tf::Frontend::FileType::Binary);
46}
47
48} // namespace
49
50//
51// Logging Support
52//
53namespace
54{
55
56struct Logger final : public hermes::Source
57{
58 Logger(hermes::Context *ctx) { activate(ctx->sources(), ctx->bus()); }
59 ~Logger() { deactivate(); }
60};
61
62struct LoggingContext
63{
64 static hermes::Context *get(void)
65 {
67
68 static hermes::Context *ctx = nullptr;
69
70 if (ctx == nullptr)
71 {
72 ctx = new hermes::Context;
73 ctx->sinks()->append(std::make_unique<hermes::ConsoleReporter>());
74 ctx->config(std::make_unique<EnvConfig>("TF2NNPKG_Log"));
75 }
76
77 return ctx;
78 }
79};
80
81void print_help()
82{
83 std::cerr << "Usage:" << std::endl;
84 std::cerr << " tf2nnpkg --info <path/to/info>" << std::endl;
85 std::cerr << " --graphdef <path/to/pb>" << std::endl;
86 std::cerr << " -o <path/to/package/dir>" << std::endl;
87}
88
89} // namespace
90
/// Declare a local ::Logger named @p name that is attached to ::LoggingContext::get().
#define LOGGER(name) ::Logger name{::LoggingContext::get()}

/// Forward to hermes' HERMES_INFO for logger @p name (used as `INFO(l) << ...;`).
#define INFO(name) HERMES_INFO(name)
95
namespace
{

/// Print the generic fallback diagnostic emitted when an exception escapes the driver.
void internal_error(void)
{
  constexpr const char *message = "tf2nnpkg: internal compiler error";
  std::cerr << message << std::endl;

  // TODO Explain how to report a bug
}

} // namespace
107
108namespace
109{
110
111std::string extract_modelname(std::string tf_path)
112{
113 auto filename = filesystem::basename(tf_path);
114 // TODO Find better way
115 const std::string key = ".pb";
116 auto suffix_index = filename.find(key);
117 assert(suffix_index != std::string::npos);
118 assert(suffix_index + key.size() == filename.size());
119
120 return filename.substr(0, suffix_index);
121}
122
123class EntryFunctor
124{
125public:
126 EntryFunctor();
127
128public:
129 ~EntryFunctor();
130
131public:
132 int operator()(int argc, char **argv) const;
133};
134
135EntryFunctor::EntryFunctor()
136{
137 // NOTE Implement initialization here
138}
139
140EntryFunctor::~EntryFunctor()
141{
142 // NOTE Implement finialization here
143}
144
145int EntryFunctor::operator()(int argc, char **argv) const
146{
148
149 // This line allows users to control all the moco-tf loggers via TF2NNPKG_Log_Frontend
150 moco::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2NNPKG_Log_Frontend"));
151 // This line allows users to control all the exo-circle loggers via TF2NNPKG_Log_Backend
152 exo::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2NNPKG_Log_Backend"));
153
154 LOGGER(l);
155
156 // Simple argument parser (based on map)
157 std::map<std::string, std::function<void(const std::string &arg)>> argparse;
158
159 std::string arg_info;
160 std::string arg_graphdef;
161 std::string arg_output;
162
163 argparse["--info"] = [&](const std::string &arg) { arg_info = arg; };
164 argparse["--graphdef"] = [&](const std::string &arg) { arg_graphdef = arg; };
165 argparse["-o"] = [&](const std::string &arg) { arg_output = arg; };
166
167 // TODO We need better args parsing in future
168
169 for (int n = 1; n < argc; n += 2)
170 {
171 const std::string tag{argv[n]};
172 const std::string arg{argv[n + 1]};
173
174 auto it = argparse.find(tag);
175 if (it == argparse.end())
176 {
177 std::cerr << "Option '" << tag << "' is not supported" << std::endl;
178 print_help();
179 return 255;
180 }
181
182 it->second(arg);
183 }
184 if (arg_info.empty() || arg_graphdef.empty() || arg_output.empty())
185 {
186 print_help();
187 return 255;
188 }
189
190 // Input paths
191 std::string info_path = arg_info;
192 std::string tf_path = arg_graphdef; // .pb file
193
194 // Output paths
195 std::string outdir_path = arg_output;
196 std::string modelname = extract_modelname(filesystem::normalize_path(tf_path));
197 std::string nnpkg_path = filesystem::join(outdir_path, modelname);
198 std::string model_filename = modelname + ".circle";
199 std::string metadata_path = filesystem::join(nnpkg_path, "metadata");
200 std::string circle_path = filesystem::join(nnpkg_path, model_filename);
201 std::string manifest_path = filesystem::join(metadata_path, "MANIFEST");
202
203 std::cout << "Read '" << info_path << "'" << std::endl;
204
206 {
207 for (const auto &info : nnkit::support::tftestinfo::parse(info_path.c_str()))
208 {
209 switch (info->kind())
210 {
213 sig.shape(info->name(), info->shape());
214 break;
215
218 sig.shape(info->name(), info->shape());
219 break;
220
221 default:
222 throw std::runtime_error{"Unknown kind"};
223 }
224 }
225 }
226
227 std::cout << "Read '" << info_path << "' - Done" << std::endl;
228
229 std::cout << "Import from '" << tf_path << "'" << std::endl;
230 auto g = import(sig, tf_path);
231 std::cout << "Import from '" << tf_path << "' - Done" << std::endl;
232
233 INFO(l) << "Import Graph" << std::endl;
234 INFO(l) << locop::fmt<locop::Formatter::LinearV1>(g) << std::endl;
235
236 if (not filesystem::is_dir(outdir_path))
237 {
238 std::cout << "Make output directory '" << outdir_path << "'" << std::endl;
239 if (not filesystem::mkdir(outdir_path))
240 throw std::runtime_error("Fail to make directory " + outdir_path);
241 std::cout << "Make output directory '" << outdir_path << "' - Done" << std::endl;
242 }
243
244 if (not filesystem::is_dir(nnpkg_path))
245 {
246 std::cout << "Make package directory '" << nnpkg_path << "'" << std::endl;
247 if (not filesystem::mkdir(nnpkg_path))
248 throw std::runtime_error("Fail to make directory " + nnpkg_path);
249 std::cout << "Make package directory '" << nnpkg_path << "' - Done" << std::endl;
250 }
251
252 std::cout << "Export into '" << circle_path << "'" << std::endl;
253 exo::CircleExporter(g.get()).dumpToFile(circle_path.c_str());
254 std::cout << "Export into '" << circle_path << "' - Done" << std::endl;
255
256 if (not filesystem::is_dir(metadata_path))
257 {
258 std::cout << "Make metadata directory '" << metadata_path << "'" << std::endl;
259 if (not filesystem::mkdir(metadata_path))
260 throw std::runtime_error("Fail to make directory " + metadata_path);
261 std::cout << "Make metadata directory '" << metadata_path << "' - Done" << std::endl;
262 }
263
264 std::cout << "Make manifest file '" << manifest_path << "'" << std::endl;
265 std::ofstream manifest_file;
266 manifest_file.open(manifest_path, std::ios::out | std::ios::binary);
267 manifest_file << "{\n";
268 manifest_file << " \"major-version\" : \"1\",\n";
269 manifest_file << " \"minor-version\" : \"0\",\n";
270 manifest_file << " \"patch-version\" : \"0\",\n";
271 manifest_file << " \"models\" : [ \"" + model_filename + "\" ],\n";
272 manifest_file << " \"model-types\" : [ \"circle\" ]\n";
273 manifest_file << "}";
274 manifest_file.close();
275 std::cout << "Make manifest file '" << manifest_path << "' - Done" << std::endl;
276
277 return 0;
278}
279
280} // namespace
281
282int main(int argc, char **argv)
283{
284 // TODO Add "signal" handler here
285
286 try
287 {
288 EntryFunctor entry;
289 return entry(argc, argv);
290 }
291 catch (...)
292 {
293 // Catch all the exception and print the default error message.
294 internal_error();
295 }
296
297 // EX_SOFTWARE defined in "sysexits.h"
298 return 70;
299}
This file contains functions to parse test.info files in moco/test/tf.
int main(void)
void dumpToFile(const char *path) const
write to a file
Logging controller.
Definition Context.h:40
Source::Registry * sources(void)
Definition Context.h:55
Sink::Registry * sinks(void)
Definition Context.h:64
MessageBus * bus(void)
Definition Context.h:48
const Config * config(void) const
Get the global configuration.
Definition Context.cpp:24
Message Source.
Definition Source.h:35
void deactivate(void)
Definition Source.cpp:49
void activate(Registry *, MessageBus *)
Definition Source.cpp:37
std::unique_ptr< loco::Graph > load(const ModelSignature &, const char *, FileType) const
Definition Frontend.cpp:193
int entry(const int argc, char **argv)
Definition Driver.cpp:53
volatile const char info[]
KnobTrait< K >::ValueType get(void)
bool mkdir(const std::string &path)
bool is_dir(const std::string &path)
std::string join(const std::string &path1, const std::string &path2)
std::string normalize_path(const std::string &path)
Normalize compatible separator in path to default separator.
std::string basename(const std::string &path)
std::vector< std::unique_ptr< ParsedTensor > > parse(const char *info_path)
Function to parse test.info.
static hermes::Context * get(void)
virtual void append(std::unique_ptr< Sink > &&)=0
static hermes::Context * get(void)
Class to store information to run a model. Normally this info comes from users via CLI params or conf...
void add_input(const TensorName &input)
void add_output(const TensorName &output)
void shape(const std::string &node_name, const angkor::TensorShape &shape)
Adds node name and its shape provided from user.
const std::string & name() const
Definition Names.h:60
#define LOGGER(name)
Definition tf2nnpkg.cpp:91
#define INFO(name)
Definition tf2nnpkg.cpp:94