ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Driver.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "CustomopConfLoader.h"
18
19#include <moco/LoggingContext.h>
20#include <moco/tf/Frontend.h>
21#include <exo/LoggingContext.h>
22#include <exo/TFLExporter.h>
23
25
27
29#include <hermes/EnvConfig.h>
30
31#include <cassert>
32#include <memory>
33#include <iostream>
34#include <stdexcept>
35#include <string>
36
37namespace
38{
39
40std::unique_ptr<loco::Graph> import(const moco::ModelSignature &sig, const std::string &path)
41{
42 moco::tf::Frontend frontend;
43 return frontend.load(sig, path.c_str(), moco::tf::Frontend::FileType::Binary);
44}
45
46} // namespace
47
48//
49// Logging Support
50//
51namespace
52{
53
54struct Logger final : public hermes::Source
55{
56 Logger(hermes::Context *ctx) { activate(ctx->sources(), ctx->bus()); }
57 ~Logger() { deactivate(); }
58};
59
60struct LoggingContext
61{
62 static hermes::Context *get(void)
63 {
65
66 static hermes::Context *ctx = nullptr;
67
68 if (ctx == nullptr)
69 {
70 ctx = new hermes::Context;
71 ctx->sinks()->append(std::make_unique<hermes::ConsoleReporter>());
72 ctx->config(std::make_unique<EnvConfig>("TF2TFLITE_Log"));
73 }
74
75 return ctx;
76 }
77};
78
/**
 * @brief Write the command-line usage summary to standard error
 *
 * Two lines are emitted: the positional arguments, then the optional
 * `--customop` switch. Each line is terminated with std::endl.
 */
void print_help()
{
  auto &err = std::cerr;

  err << "Usage: tf2tflite <path/to/info> <path/to/pb> <path/to/tflite/model> ";
  err << std::endl;

  err << "Options: --customop <path/to/customop.conf>";
  err << std::endl;
}
84
85} // namespace
86
// Declares a local ::Logger variable called `name`, constructed from the
// context returned by ::LoggingContext::get().
87#define LOGGER(name) \
88 ::Logger name { ::LoggingContext::get() }
89
// Shorthand that forwards `name` to HERMES_INFO.
90#define INFO(name) HERMES_INFO(name)
91
// Entry point of the tf2tflite driver.
//
// Accepted command lines (see print_help):
//   tf2tflite <path/to/info> <path/to/pb> <path/to/tflite/model>
//   tf2tflite <path/to/info> <path/to/pb> <path/to/tflite/model> --customop <path/to/customop.conf>
//
// Visible steps: configure the frontend/backend logging contexts, parse the
// info file, optionally load the customop configuration, import the .pb file
// into a graph, then export that graph to the tflite path.
//
// Returns 0 after the export, or 255 (after printing the usage) when the
// argument count is neither 4 nor 6 or when argv[4] is not "--customop".
// A std::runtime_error is thrown from the `default:` branch of the switch
// and is not caught in this function.
//
// NOTE(review): the numbers embedded in this listing skip 94, 116, 122-123 and
// 127-128, so some original lines appear to be missing from this view. In
// particular `sig` and the unqualified `EnvConfig` have no visible
// declaration, and no `case` labels are visible ahead of `default:` - confirm
// against the real file before editing the logic.
92int main(int argc, char **argv)
93{
95
96 // This line allows users to control all the moco-tf loggers via TF2TFLITE_Log_Frontend
97 moco::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2TFLITE_Log_Frontend"));
98 // This line allows users to control all the exo-tflite loggers via TF2TFLITE_Log_Backend
99 exo::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2TFLITE_Log_Backend"));
100
 // Driver-local logger `l`, used by the INFO(l) statements further down
101 LOGGER(l);
102
103 // TODO We need better args parsing in future
 // argc counts the program name: 4 = three positional paths, 6 = those plus "--customop <conf>"
104 if (!(argc == 4 or argc == 6))
105 {
106 print_help();
107 return 255;
108 }
109
110 std::string info_path{argv[1]};
111 std::string tf_path{argv[2]}; // .pb file
112 std::string tflite_path{argv[3]};
113
114 std::cout << "Read '" << info_path << "'" << std::endl;
115
 // Record the shape of every tensor parsed from the info file into `sig`
 // NOTE(review): the declaration of `sig` and the per-kind `case` labels are not visible here
117 {
118 for (const auto &info : nnkit::support::tftestinfo::parse(info_path.c_str()))
119 {
120 switch (info->kind())
121 {
124 sig.shape(info->name(), info->shape());
125 break;
126
129 sig.shape(info->name(), info->shape());
130 break;
131
132 default:
133 throw std::runtime_error{"Unknown kind"};
134 }
135 }
136 }
137
138 if (argc == 6) // optional parameter: path of customop.conf
139 {
140 if (std::string{argv[4]} == "--customop")
141 {
142 tf2tflite::load_customop_conf(argv[5], sig);
143 }
144 else
145 {
 // Any other fifth argument is rejected the same way as a wrong argument count
146 print_help();
147 return 255;
148 }
149 }
150
151 std::cout << "Read '" << info_path << "' - Done" << std::endl;
152
153 std::cout << "Import from '" << tf_path << "'" << std::endl;
154 auto g = import(sig, tf_path);
155 std::cout << "Import from '" << tf_path << "' - Done" << std::endl;
156
 // Log the imported graph through the driver-local logger
157 INFO(l) << "Import Graph" << std::endl;
158 INFO(l) << locop::fmt<locop::Formatter::LinearV1>(g) << std::endl;
159
160 std::cout << "Export into '" << tflite_path << "'" << std::endl;
 // `g` keeps ownership of the graph; the exporter only receives the raw pointer
161 exo::TFLExporter(g.get()).dumpToFile(tflite_path.c_str());
162 std::cout << "Export into '" << tflite_path << "' - Done" << std::endl;
163
164 return 0;
165}
This file contains functions to parse test.info files in moco/test/tf.
int main(void)
#define LOGGER(name)
Definition Log.h:65
#define INFO(name)
Definition Log.h:68
void dumpToFile(const char *path) const
write to a file
Logging controller.
Definition Context.h:40
Source::Registry * sources(void)
Definition Context.h:55
Sink::Registry * sinks(void)
Definition Context.h:64
MessageBus * bus(void)
Definition Context.h:48
const Config * config(void) const
Get the global configuration.
Definition Context.cpp:24
Message Source.
Definition Source.h:35
void deactivate(void)
Definition Source.cpp:49
void activate(Registry *, MessageBus *)
Definition Source.cpp:37
std::unique_ptr< loco::Graph > load(const ModelSignature &, const char *, FileType) const
Definition Frontend.cpp:193
volatile const char info[]
KnobTrait< K >::ValueType get(void)
std::vector< std::unique_ptr< ParsedTensor > > parse(const char *info_path)
Function to parse test.info.
void load_customop_conf(const std::string &path, moco::ModelSignature &sig)
Loads customop.conf into ModelSignature.
static hermes::Context * get(void)
virtual void append(std::unique_ptr< Sink > &&)=0
static hermes::Context * get(void)
Class to store information to run a model. Normally this info comes from users via CLI params or conf...
void add_input(const TensorName &input)
void add_output(const TensorName &output)
void shape(const std::string &node_name, const angkor::TensorShape &shape)
Adds node name and its shape provided from user.
const std::string & name() const
Definition Names.h:60