ONE - On-device Neural Engine
circle_eval_diff::CircleEvalDiff Class Reference (final)

#include <CircleEvalDiff.h>

Data Structures

struct  Context
 

Public Member Functions

 CircleEvalDiff (std::unique_ptr< Context > &&ctx)
 
 ~CircleEvalDiff ()
 
void init ()
 
void evalDiff (void) const
 

Detailed Description

Definition at line 47 of file CircleEvalDiff.h.

Constructor & Destructor Documentation

◆ CircleEvalDiff()

circle_eval_diff::CircleEvalDiff::CircleEvalDiff ( std::unique_ptr< Context > &&  ctx)

Definition at line 142 of file CircleEvalDiff.cpp.

142 : _ctx(std::move(ctx))
143{
144 // DO NOTHING
145}

◆ ~CircleEvalDiff()

circle_eval_diff::CircleEvalDiff::~CircleEvalDiff ( ) = default

Member Function Documentation

◆ evalDiff()

void circle_eval_diff::CircleEvalDiff::evalDiff ( void  ) const

Definition at line 201 of file CircleEvalDiff.cpp.

202{
203 auto first_input_loader = circle_eval_diff::makeDataLoader(
204 _ctx->first_input_data_path, _ctx->input_format, ::inputs_of(_first_module.get()));
205 auto second_input_loader = circle_eval_diff::makeDataLoader(
206 _ctx->second_input_data_path, _ctx->input_format, ::inputs_of(_second_module.get()));
207
208 for (uint32_t data_idx = 0; data_idx < first_input_loader->size(); data_idx++)
209 {
210 std::cout << "Evaluating " << data_idx << "'th data" << std::endl;
211
212 auto first_data = first_input_loader->get(data_idx);
213 auto second_data = second_input_loader->get(data_idx);
214
215 auto first_output = interpret(_first_module.get(), first_data);
216 auto second_output = interpret(_second_module.get(), second_data);
217
218 for (auto &metric : _metrics)
219 {
220 metric->accumulate(first_output, second_output);
221 }
222
223 if (_ctx.get()->output_prefix.empty())
224 continue;
225
226 for (uint32_t i = 0; i < first_output.size(); i++)
227 {
228 auto out = first_output[i];
229 writeDataToFile(_ctx.get()->output_prefix + "." + std::to_string(data_idx) + ".first.output" +
230 std::to_string(i),
231 (char *)(out->buffer()), out->byte_size());
232 }
233 for (uint32_t i = 0; i < second_output.size(); i++)
234 {
235 auto out = second_output[i];
236 writeDataToFile(_ctx.get()->output_prefix + "." + std::to_string(data_idx) +
237 ".second.output" + std::to_string(i),
238 (char *)(out->buffer()), out->byte_size());
239 }
240 }
241
242 for (auto &metric : _metrics)
243 {
244 std::cout << metric.get() << std::endl;
245 }
246}
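Functions referenced in the listing above:
std::vector< std::shared_ptr< Tensor > > interpret(const luci::Module *module, const InputDataLoader::Data &data)
std::unique_ptr< InputDataLoader > makeDataLoader(const std::string &file_path, const InputFormat &format, const std::vector< loco::Node * > &input_nodes)
void writeDataToFile(const std::string &file_path, const std::string &data)
  Writes data to file_path. (The listing calls writeDataToFile with three arguments — path, buffer pointer, and byte size — so it presumably uses a different overload than the one shown here.)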

References circle_eval_diff::interpret() and circle_eval_diff::makeDataLoader().

Referenced by entry().

◆ init()

void circle_eval_diff::CircleEvalDiff::init ( )

Definition at line 149 of file CircleEvalDiff.cpp.

150{
151 _first_module = import(_ctx->first_model_path);
152 _second_module = import(_ctx->second_model_path);
153
154 // Check modules have the same output signature (dtype/shape)
155 // Exception will be thrown if they have different signature
156 checkOutputs(_first_module.get(), _second_module.get());
157
158 // Set metric
159 std::unique_ptr<MetricPrinter> metric;
160 for (auto metric : _ctx->metric)
161 {
162 switch (metric)
163 {
164 case Metric::MAE:
165 {
166 _metrics.emplace_back(std::make_unique<MAEPrinter>());
167 break;
168 }
169 case Metric::MAPE:
170 {
171 _metrics.emplace_back(std::make_unique<MAPEPrinter>());
172 break;
173 }
174 case Metric::MPEIR:
175 {
176 _metrics.emplace_back(std::make_unique<MPEIRPrinter>());
177 break;
178 }
179 case Metric::MTOP1:
180 {
181 _metrics.emplace_back(std::make_unique<TopKMatchPrinter>(1));
182 break;
183 }
184 case Metric::MTOP5:
185 {
186 _metrics.emplace_back(std::make_unique<TopKMatchPrinter>(5));
187 break;
188 }
189 case Metric::MSE:
190 {
191 _metrics.emplace_back(std::make_unique<MSEPrinter>());
192 break;
193 }
194 default:
195 throw std::runtime_error("Unsupported metric.");
196 }
197 _metrics.back()->init(_first_module.get(), _second_module.get());
198 }
199}

References circle_eval_diff::MAE, circle_eval_diff::MAPE, circle_eval_diff::MPEIR, circle_eval_diff::MSE, circle_eval_diff::MTOP1, and circle_eval_diff::MTOP5.

Referenced by entry().


The documentation for this class was generated from the following files: