ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Driver.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <arser/arser.h>

#include "CircleModel.h"
#include "TFLModel.h"

#include <vconone/vconone.h>
28
29void print_version(void)
30{
31 std::cout << "tflite2circle version " << vconone::get_string() << std::endl;
32 std::cout << vconone::get_copyright() << std::endl;
33}
34
35int entry(int argc, char **argv)
36{
37 arser::Arser arser{"tflite2circle is a Tensorflow lite to circle model converter"};
38
41
42 arser.add_argument("tflite").help("Source tflite file path to convert");
43 arser.add_argument("circle").help("Target circle file path");
44
45 try
46 {
47 arser.parse(argc, argv);
48 }
49 catch (const std::runtime_error &err)
50 {
51 std::cerr << err.what() << std::endl;
52 std::cout << arser;
53 return 255;
54 }
55
56 std::string tfl_path = arser.get<std::string>("tflite");
57 std::string circle_path = arser.get<std::string>("circle");
58 // read tflite file
59 tflite2circle::TFLModel tfl_model(tfl_path);
60 if (not tfl_model.verify_data())
61 {
62 std::cerr << "ERROR: Failed to verify tflite '" << tfl_path << "'" << std::endl;
63 return 255;
64 }
65
66 // create flatbuffer builder
67 auto flatbuffer_builder = std::make_unique<flatbuffers::FlatBufferBuilder>(1024);
68
69 // convert tflite to circle
70 const std::vector<char> &raw_data = tfl_model.raw_data();
71 tflite2circle::CircleModel circle_model{flatbuffer_builder, raw_data};
72
73 circle_model.load_offsets(tfl_model.get_model());
74 circle_model.model_build();
75 circle_model.finalize();
76
77 std::ofstream outfile{circle_path, std::ios::binary};
78
79 outfile.write(circle_model.base(), circle_model.size());
80 outfile.close();
81 // TODO find a better way of error handling
82 if (outfile.fail())
83 {
84 std::cerr << "ERROR: Failed to write circle '" << circle_path << "'" << std::endl;
85 return 255;
86 }
87
88 return 0;
89}
static void add_version(Arser &arser, const std::function< void(void)> &func)
Definition arser.h:755
static void add_verbose(Arser &arser)
Definition arser.h:765
int entry(int argc, char **argv)
Definition Driver.cpp:35
void print_version(void)
Definition Driver.cpp:29
Definition arser.h:39
std::string get_copyright(void)
get_copyright returns the copyright string
Definition version.cpp:54
std::string get_string(void)
get_string returns the version string in major.minor.patch form (without the build part)
Definition version.cpp:44