ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Driver.cpp File Reference
#include "CustomopConfLoader.h"
#include <moco/LoggingContext.h>
#include <moco/tf/Frontend.h>
#include <exo/LoggingContext.h>
#include <exo/TFLExporter.h>
#include <nnkit/support/tftestinfo/TensorInfoParser.h>
#include <locop/FormattedGraph.h>
#include <hermes/ConsoleReporter.h>
#include <hermes/EnvConfig.h>
#include <cassert>
#include <memory>
#include <iostream>
#include <stdexcept>
#include <string>

Go to the source code of this file.

Macros

#define LOGGER(name)    ::Logger name { ::LoggingContext::get() }
 
#define INFO(name)   HERMES_INFO(name)
 

Functions

int main (int argc, char **argv)
 

Macro Definition Documentation

◆ INFO

#define INFO (   name)    HERMES_INFO(name)

Definition at line 90 of file Driver.cpp.

◆ LOGGER

#define LOGGER (   name)     ::Logger name { ::LoggingContext::get() }

Definition at line 87 of file Driver.cpp.

88 { ::LoggingContext::get() }

Function Documentation

◆ main()

int main ( int  argc,
char **  argv 
)

Definition at line 92 of file Driver.cpp.

93{
95
96 // This line allows users to control all the moco-tf loggers via TF2TFLITE_Log_Frontend
97 moco::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2TFLITE_Log_Frontend"));
98 // This line allows users to control all the exo-tflite loggers via TF2TFLITE_Log_Backend
99 exo::LoggingContext::get()->config(std::make_unique<EnvConfig>("TF2TFLITE_Log_Backend"));
100
101 LOGGER(l);
102
103 // TODO We need better args parsing in future
104 if (!(argc == 4 or argc == 6))
105 {
106 print_help();
107 return 255;
108 }
109
110 std::string info_path{argv[1]};
111 std::string tf_path{argv[2]}; // .pb file
112 std::string tflite_path{argv[3]};
113
114 std::cout << "Read '" << info_path << "'" << std::endl;
115
117 {
118 for (const auto &info : nnkit::support::tftestinfo::parse(info_path.c_str()))
119 {
120 switch (info->kind())
121 {
124 sig.shape(info->name(), info->shape());
125 break;
126
129 sig.shape(info->name(), info->shape());
130 break;
131
132 default:
133 throw std::runtime_error{"Unknown kind"};
134 }
135 }
136 }
137
138 if (argc == 6) // optional parameter: path of customop.conf
139 {
140 if (std::string{argv[4]} == "--customop")
141 {
142 tf2tflite::load_customop_conf(argv[5], sig);
143 }
144 else
145 {
146 print_help();
147 return 255;
148 }
149 }
150
151 std::cout << "Read '" << info_path << "' - Done" << std::endl;
152
153 std::cout << "Import from '" << tf_path << "'" << std::endl;
154 auto g = import(sig, tf_path);
155 std::cout << "Import from '" << tf_path << "' - Done" << std::endl;
156
157 INFO(l) << "Import Graph" << std::endl;
158 INFO(l) << locop::fmt<locop::Formatter::LinearV1>(g) << std::endl;
159
160 std::cout << "Export into '" << tflite_path << "'" << std::endl;
161 exo::TFLExporter(g.get()).dumpToFile(tflite_path.c_str());
162 std::cout << "Export into '" << tflite_path << "' - Done" << std::endl;
163
164 return 0;
165}
#define LOGGER(name)
Definition Log.h:65
#define INFO(name)
Definition Log.h:68
void dumpToFile(const char *path) const
Write to a file.
const Config * config(void) const
Get the global configuration.
Definition Context.cpp:24
volatile const char info[]
std::vector< std::unique_ptr< ParsedTensor > > parse(const char *info_path)
Function to parse test.info.
void load_customop_conf(const std::string &path, moco::ModelSignature &sig)
Loads customop.conf into ModelSignature.
static hermes::Context * get(void)
static hermes::Context * get(void)
Class to store information to run a model. Normally this info comes from users via CLI params or conf...
void add_input(const TensorName &input)
void add_output(const TensorName &output)
void shape(const std::string &node_name, const angkor::TensorShape &shape)
Adds a node name and its shape as provided by the user.
const std::string & name() const
Definition Names.h:60

References moco::ModelSignature::add_input(), moco::ModelSignature::add_output(), hermes::Context::config(), exo::TFLExporter::dumpToFile(), exo::LoggingContext::get(), moco::LoggingContext::get(), INFO, info, nnkit::support::tftestinfo::ParsedTensor::Input, tf2tflite::load_customop_conf(), LOGGER, moco::TensorName::name(), nnkit::support::tftestinfo::ParsedTensor::Output, nnkit::support::tftestinfo::parse(), and moco::ModelSignature::shape().