ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tf2circle.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "CustomopConfLoader.h"
18
19#include <moco/LoggingContext.h>
20#include <moco/tf/Frontend.h>
21#include <exo/LoggingContext.h>
22#include <exo/CircleExporter.h>
23
25
27
29#include <hermes/EnvConfig.h>
30
31#include <cassert>
32#include <memory>
33#include <iostream>
34#include <stdexcept>
35#include <string>
36
37namespace
38{
39
40std::unique_ptr<loco::Graph> import(const moco::ModelSignature &sig, const std::string &path)
41{
42 moco::tf::Frontend frontend;
43 return frontend.load(sig, path.c_str(), moco::tf::Frontend::FileType::Binary);
44}
45
46} // namespace
47
48//
49// Logging Support
50//
51namespace
52{
53
54struct Logger final : public hermes::Source
55{
56 Logger(hermes::Context *ctx) { activate(ctx->sources(), ctx->bus()); }
57 ~Logger() { deactivate(); }
58};
59
// Owner of the hermes logging context used by this tool's LOGGER macro.
60struct LoggingContext
61{
// Return the shared hermes::Context, creating it on the first call.
// On creation a ConsoleReporter sink is appended and the context is
// configured through EnvConfig keyed on "TF2CIRCLE_Log".
// The context is allocated with `new` and never deleted, so the returned
// pointer stays valid for the rest of the process.
// NOTE(review): `EnvConfig` is not declared anywhere in this listing; a
// declaration (presumably an alias of hermes::EnvConfig) appears to be
// missing here — confirm against the real source before editing.
62 static hermes::Context *get(void)
63 {
65
66 static hermes::Context *ctx = nullptr;
67
 // Lazy one-time initialization; no locking is done here.
68 if (ctx == nullptr)
69 {
70 ctx = new hermes::Context;
71 ctx->sinks()->append(std::make_unique<hermes::ConsoleReporter>());
72 ctx->config(std::make_unique<EnvConfig>("TF2CIRCLE_Log"));
73 }
74
75 return ctx;
76 }
77};
78
/**
 * @brief Print the tf2circle command-line usage to std::cerr.
 */
void print_help()
{
  std::ostream &out = std::cerr;

  out << "Usage: tf2circle <path/to/info> <path/to/pb> <path/to/circle/model> " << std::endl;
  out << "Options: --customop <path/to/customop.conf>" << std::endl;
}
84
85} // namespace
86
// Declare a scoped logger variable `name` bound to this tool's LoggingContext.
#define LOGGER(name) ::Logger name{::LoggingContext::get()}

// Start an INFO-level message on logger `name` (thin alias of HERMES_INFO).
#define INFO(name) HERMES_INFO(name)
91
92namespace
93{
94
/**
 * @brief Report an unexpected failure of tf2circle on std::cerr.
 *
 * Writes a single diagnostic line and flushes it via std::endl.
 */
void internal_error(void)
{
  std::ostream &diag = std::cerr;
  diag << "tf2circle: internal compiler error" << std::endl;

  // TODO Explain how to report a bug
}
101
102} // namespace
103
104namespace
105{
106
/**
 * @brief Callable wrapper around the tf2circle driver.
 *
 * main() creates one instance and invokes it with the command-line
 * arguments. The constructor and destructor are placeholders for
 * set-up and tear-down work.
 */
class EntryFunctor
{
public:
  EntryFunctor();
  ~EntryFunctor();

  /// Run the driver; the result is used as the process exit code.
  int operator()(int argc, char **argv) const;
};
118
119EntryFunctor::EntryFunctor()
120{
121 // NOTE Implement initialization here
122}
123
124EntryFunctor::~EntryFunctor()
125{
126 // NOTE Implement finialization here
127}
128
// Run the tf2circle conversion and return the process exit code.
//
// Arguments:
//   argv[1]  test.info file describing input/output tensors
//   argv[2]  TensorFlow model (.pb) path
//   argv[3]  output circle model path
//   argv[4], argv[5] (optional)  "--customop" followed by a customop.conf path
// Returns 255 after printing usage when argc is not 4 or 6, or when argv[4]
// is not "--customop"; returns 0 after the circle file has been written.
// Exceptions (e.g. std::runtime_error "Unknown kind") are not caught here and
// propagate to main().
//
// NOTE(review): this listing does not show the declarations of `EnvConfig`
// and `sig`, nor the `case` labels of the switch below; confirm them against
// the real source file before editing this function.
129int EntryFunctor::operator()(int argc, char **argv) const
130{
132
133 // This line allows users to control all the moco-tf loggers via TF2CIRCLE_Log_Frontend
134 moco::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2CIRCLE_Log_Frontend"));
135 // This line allows users to control all the exo-circle loggers via TF2CIRCLE_Log_Backend
136 exo::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2CIRCLE_Log_Backend"));
137
138 LOGGER(l);
139
140 // TODO We need better args parsing in future
141 if (!(argc == 4 or argc == 6))
142 {
143 print_help();
144 return 255;
145 }
146
147 std::string info_path{argv[1]};
148 std::string tf_path{argv[2]}; // .pb file
149 std::string circle_path{argv[3]};
150
151 std::cout << "Read '" << info_path << "'" << std::endl;
152
 // Build the model signature from the parsed test.info entries; every entry
 // contributes its shape, and an unrecognized kind aborts with an exception.
 // Presumably the two (not visible) case labels distinguish inputs from
 // outputs — TODO confirm.
154 {
155 for (const auto &info : nnkit::support::tftestinfo::parse(info_path.c_str()))
156 {
157 switch (info->kind())
158 {
161 sig.shape(info->name(), info->shape());
162 break;
163
166 sig.shape(info->name(), info->shape());
167 break;
168
169 default:
170 throw std::runtime_error{"Unknown kind"};
171 }
172 }
173 }
174
 // Custom-op settings are loaded into the same signature before import.
175 if (argc == 6) // optional parameter: path of customop.conf
176 {
177 if (std::string{argv[4]} == "--customop")
178 {
179 tf2circle::load_customop_conf(argv[5], sig);
180 }
181 else
182 {
183 print_help();
184 return 255;
185 }
186 }
187
188 std::cout << "Read '" << info_path << "' - Done" << std::endl;
189
190 std::cout << "Import from '" << tf_path << "'" << std::endl;
191 auto g = import(sig, tf_path);
192 std::cout << "Import from '" << tf_path << "' - Done" << std::endl;
193
 // Dump the imported graph to the INFO log channel.
194 INFO(l) << "Import Graph" << std::endl;
195 INFO(l) << locop::fmt<locop::Formatter::LinearV1>(g) << std::endl;
196
197 std::cout << "Export into '" << circle_path << "'" << std::endl;
198 exo::CircleExporter(g.get()).dumpToFile(circle_path.c_str());
199 std::cout << "Export into '" << circle_path << "' - Done" << std::endl;
200
201 return 0;
202}
203
204} // namespace
205
206int main(int argc, char **argv)
207{
208 // TODO Add "signal" handler here
209
210 try
211 {
212 EntryFunctor entry;
213 return entry(argc, argv);
214 }
215 catch (...)
216 {
217 // Catch all the exception and print the default error message.
218 internal_error();
219 }
220
221 // EX_SOFTWARE defined in "sysexits.h"
222 return 70;
223}
This file contains functions to parse test.info files in moco/test/tf.
int main(void)
void dumpToFile(const char *path) const
write to a file
Logging controller.
Definition Context.h:40
Source::Registry * sources(void)
Definition Context.h:55
Sink::Registry * sinks(void)
Definition Context.h:64
MessageBus * bus(void)
Definition Context.h:48
const Config * config(void) const
Get the global configuration.
Definition Context.cpp:24
Message Source.
Definition Source.h:35
void deactivate(void)
Definition Source.cpp:49
void activate(Registry *, MessageBus *)
Definition Source.cpp:37
std::unique_ptr< loco::Graph > load(const ModelSignature &, const char *, FileType) const
Definition Frontend.cpp:193
int entry(const int argc, char **argv)
Definition Driver.cpp:53
volatile const char info[]
KnobTrait< K >::ValueType get(void)
std::vector< std::unique_ptr< ParsedTensor > > parse(const char *info_path)
Function to parse test.info.
void load_customop_conf(const std::string &path, moco::ModelSignature &sig)
Loads customop.conf into ModelSignature.
static hermes::Context * get(void)
virtual void append(std::unique_ptr< Sink > &&)=0
static hermes::Context * get(void)
Class to store information to run a model. Normally this info comes from users via CLI params or conf...
void add_input(const TensorName &input)
void add_output(const TensorName &output)
void shape(const std::string &node_name, const angkor::TensorShape &shape)
Adds node name and its shape provided from user.
const std::string & name() const
Definition Names.h:60
#define LOGGER(name)
Definition tf2circle.cpp:87
#define INFO(name)
Definition tf2circle.cpp:90