ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Driver.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "OMInterpreter.h"
18#include "arser/arser.h"
19
20#include <stdexcept>
21#include <cstdlib>
22#include <fstream>
23#include <vector>
24#include <string>
25#include <iostream>
26#include <random>
27
namespace
{

// Raw byte buffer used to hold the model file contents.
using DataBuffer = std::vector<char>;

/*
 * @brief Fill data[0 .. data_size) with uniformly distributed random bytes.
 *
 * A fresh std::mt19937 seeded from std::random_device is created on every call.
 *
 * @param data      destination buffer (may be null when data_size is 0)
 * @param data_size number of bytes to generate
 */
void generateRandomData(char *data, size_t data_size)
{
  std::random_device rd;
  std::mt19937 gen(rd());
  // uniform_int_distribution is only specified for short/int/long/long long and their unsigned
  // counterparts; instantiating it with uint8_t (unsigned char) is undefined behaviour and is
  // rejected outright by some standard libraries. Draw an int in [0, 255] and narrow instead.
  std::uniform_int_distribution<int> dist(0, 255);
  for (size_t i = 0; i < data_size; ++i)
    data[i] = static_cast<char>(dist(gen));
}

/*
 * @brief Read exactly data_size bytes from the beginning of `filename` into `data`.
 *
 * @throws std::runtime_error if the file cannot be opened or provides fewer than
 *         data_size bytes.
 */
void readDataFromFile(const std::string &filename, char *data, size_t data_size)
{
  std::ifstream fs(filename, std::ifstream::binary);
  if (fs.fail())
    throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
  if (fs.read(data, static_cast<std::streamsize>(data_size)).fail())
    throw std::runtime_error("Failed to read data from file \"" + filename + "\".\n");
}

/*
 * @brief Create (or truncate) `filename` and write data_size bytes from `data` into it.
 *
 * @throws std::runtime_error if the file cannot be opened or the data cannot be written.
 */
void writeDataToFile(const std::string &filename, const char *data, size_t data_size)
{
  std::ofstream fs(filename, std::ofstream::binary);
  if (fs.fail())
    throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
  // write() usually only fills the stream buffer; flush explicitly so that errors reaching the
  // file (e.g. a full disk) are reported here instead of being silently lost in the destructor.
  if (fs.write(data, static_cast<std::streamsize>(data_size)).flush().fail())
    throw std::runtime_error("Failed to write data to file \"" + filename + "\".\n");
}

} // namespace
63
64/*
65 * @brief EvalDriver main
66 *
67 * Driver for testing luci-inerpreter
68 *
69 */
70int entry(int argc, char **argv)
71{
72 // Parse command line arguments using arser
74 arser.add_argument("--model")
75 .type(arser::DataType::STR)
76 .required(true)
77 .help("Path to model.circle file");
78 arser.add_argument("--input_prefix")
79 .type(arser::DataType::STR)
80 .help("Prefix for input files (generates random inputs if not provided)");
81 arser.add_argument("--output_prefix").type(arser::DataType::STR).help("Prefix for output files");
82
83 try
84 {
85 arser.parse(argc, argv);
86 }
87 catch (const std::runtime_error &err)
88 {
89 std::cerr << err.what() << std::endl;
90 std::cerr << arser;
91 return EXIT_FAILURE;
92 }
93
94 const auto filename = arser.get<std::string>("--model");
95 std::string input_prefix;
96 std::string output_prefix;
97
98 if (arser["--input_prefix"])
99 {
100 input_prefix = arser.get<std::string>("--input_prefix");
101 }
102 if (arser["--output_prefix"])
103 {
104 output_prefix = arser.get<std::string>("--output_prefix");
105 }
106 const bool auto_input = !arser["--input_prefix"];
107 int32_t num_inputs = 1; // Default number of inputs
108
109 std::ifstream file(filename, std::ios::binary | std::ios::in);
110 if (!file.good())
111 {
112 std::string errmsg = "Failed to open file";
113 throw std::runtime_error(errmsg.c_str());
114 }
115
116 file.seekg(0, std::ios::end);
117 auto fileSize = file.tellg();
118 file.seekg(0, std::ios::beg);
119
120 // reserve capacity
121 DataBuffer model_data(fileSize);
122
123 // read the data
124 file.read(model_data.data(), fileSize);
125 if (file.fail())
126 {
127 std::string errmsg = "Failed to read file";
128 throw std::runtime_error(errmsg.c_str());
129 }
130
131 // Create interpreter.
132 onert_micro::OMInterpreter interpreter;
134 interpreter.importModel(model_data.data(), config);
135
136 num_inputs = interpreter.getNumberOfInputs(); // To initialize input buffers
137
138 // Set input.
139 // Data for n'th input is read from ${input_prefix}n
140 // (ex: Add.circle.input0, Add.circle.input1 ..)
141 int num_inference = 1;
142 for (int j = 0; j < num_inference; ++j)
143 {
144 interpreter.reset();
145 interpreter.allocateInputs();
146 for (int32_t i = 0; i < num_inputs; i++)
147 {
148 auto input_data = reinterpret_cast<char *>(interpreter.getInputDataAt(i));
149 size_t input_size = interpreter.getInputSizeAt(i);
150
151 if (auto_input)
152 {
153 generateRandomData(input_data, input_size);
154 }
155 else
156 {
157 readDataFromFile(input_prefix + std::to_string(i), input_data, input_size);
158 }
159 }
160
161 // Do inference.
162 interpreter.run(config);
163 }
164
165 // Get output.
166 int num_outputs = 1;
167 for (int i = 0; i < num_outputs; i++)
168 {
169 auto data = interpreter.getOutputDataAt(i);
170 size_t output_size = interpreter.getOutputSizeAt(i);
171
172 if (arser["--output_prefix"])
173 {
174 writeDataToFile(output_prefix + std::to_string(i), reinterpret_cast<char *>(data),
175 output_size);
176 }
177 // Otherwise, output remains in interpreter memory
178 }
179 interpreter.reset();
180 return EXIT_SUCCESS;
181}
182
// Declaration of the driver entry point invoked by main() below.
int entry(int argc, char **argv);

#ifdef NDEBUG
/*
 * Release build: run the driver; any std::exception escaping entry() is reported on
 * stderr as "ERROR: <what>" and the process exits with status 255.
 */
int main(int argc, char **argv)
{
  int exit_code = 255;
  try
  {
    exit_code = entry(argc, argv);
  }
  catch (const std::exception &e)
  {
    std::cerr << "ERROR: " << e.what() << std::endl;
  }
  return exit_code;
}
#else // NDEBUG
/*
 * Debug build: exceptions are deliberately left uncaught so that a debugger stops at
 * the throw site and the stack trace remains available.
 */
int main(int argc, char **argv) { return entry(argc, argv); }
#endif // !NDEBUG
int main(void)
int entry(int argc, char **argv)
Definition Driver.cpp:29
Definition arser.h:39
void writeDataToFile(const std::string &file_path, const std::string &data)
write data to file_path
void readDataFromFile(const std::string &filename, std::vector< char > &data, size_t data_size)
Definition Utils.cpp:65
std::vector< char > DataBuffer