ONE - On-device Neural Engine
Loading...
Searching...
No Matches
circle_eval_diff::TopKMatchPrinter Class Reference

#include <MetricPrinter.h>

Collaboration diagram for circle_eval_diff::TopKMatchPrinter:

Public Member Functions

 TopKMatchPrinter (uint32_t k)
 
void init (const luci::Module *first, const luci::Module *second)
 
void accumulate (const std::vector< std::shared_ptr< Tensor > > &first, const std::vector< std::shared_ptr< Tensor > > &second)
 
void dump (std::ostream &os) const
 
- Public Member Functions inherited from circle_eval_diff::MetricPrinter
virtual ~MetricPrinter ()=default
 

Detailed Description

Definition at line 185 of file MetricPrinter.h.

Constructor & Destructor Documentation

◆ TopKMatchPrinter()

circle_eval_diff::TopKMatchPrinter::TopKMatchPrinter ( uint32_t  k)
inline

Definition at line 188 of file MetricPrinter.h.

188 TopKMatchPrinter(uint32_t k) : _k(k) {}

Member Function Documentation

◆ accumulate()

void circle_eval_diff::TopKMatchPrinter::accumulate ( const std::vector< std::shared_ptr< Tensor > > &  first,
const std::vector< std::shared_ptr< Tensor > > &  second 
)
virtual

Implements circle_eval_diff::MetricPrinter.

Definition at line 520 of file MetricPrinter.cpp.

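520void TopKMatchPrinter::accumulate(const std::vector<std::shared_ptr<Tensor>> &first,
521 const std::vector<std::shared_ptr<Tensor>> &second)
522{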
523 assert(first.size() == second.size()); // FIX_CALLER_UNLESS
524 assert(first.size() == _intermediate.size()); // FIX_CALLER_UNLESS
525
526 for (uint32_t output_idx = 0; output_idx < _intermediate.size(); output_idx++)
527 {
528 if (in_skip_list(output_idx))
529 continue;
530
531 const auto first_output = first[output_idx];
532 const auto second_output = second[output_idx];
533
534 // Cast data to fp32 for ease of computation
535 const auto fp32_first_output = fp32(first_output);
536 const auto fp32_second_output = fp32(second_output);
537
538 accum_topk_accuracy(output_idx, fp32_first_output, fp32_second_output);
539 }
540
541 _num_data++;
542}

◆ dump()

void circle_eval_diff::TopKMatchPrinter::dump ( std::ostream &  os) const
virtual

Implements circle_eval_diff::MetricPrinter.

Definition at line 544 of file MetricPrinter.cpp.

544void TopKMatchPrinter::dump(std::ostream &os) const
545{
546 os << "Ratio of Matched Indices between Top-" << _k << " results of the models" << std::endl;
547
548 for (uint32_t output_idx = 0; output_idx < _intermediate.size(); output_idx++)
549 {
550 if (in_skip_list(output_idx))
551 continue;
552
553 const auto name = _output_names.at(output_idx);
554 const auto sum_of_topk_accuracy = _intermediate.at(output_idx);
555
556 // Compute TopKMatch
557 float mean_topk = sum_of_topk_accuracy / _num_data;
558
559 os << "Mean Top-" << _k << " match ratio for " << name << " is " << mean_topk << std::endl;
560 }
561}

◆ init()

void circle_eval_diff::TopKMatchPrinter::init ( const luci::Module *  first,
const luci::Module *  second 
)
virtual

Reimplemented from circle_eval_diff::MetricPrinter.

Definition at line 404 of file MetricPrinter.cpp.

404void TopKMatchPrinter::init(const luci::Module *first, const luci::Module *second)
405{
406 THROW_UNLESS(first != nullptr, "Invalid module.");
407 THROW_UNLESS(second != nullptr, "Invalid module.");
408
409 const auto first_output = loco::output_nodes(first->graph());
410 const auto second_output = loco::output_nodes(second->graph());
411
412 assert(first_output.size() == second_output.size()); // FIX_CALLER_UNLESS
413
414 for (uint32_t i = 0; i < first_output.size(); i++)
415 {
416 const auto first_node = loco::must_cast<luci::CircleOutput *>(first_output[i]);
417 const auto second_node = loco::must_cast<luci::CircleOutput *>(second_output[i]);
418
419 // Create places to store intermediate results
420 _intermediate.emplace_back(0.0);
421
422 // Save output names for logging
423 _output_names.emplace_back(first_node->name());
424
425 // If num_elems of an output is less than k,
426 // the output index is added to the skip list
427 if (num_elems(first_node) < _k)
428 {
429 std::cout << "Top-" << _k << "metric for " << first_node->name()
430 << " is ignored, because it has elements less than " << _k << std::endl;
431 _skip_output.emplace_back(i);
432 }
433 }
434}
#define THROW_UNLESS(COND, MSG)
loco::Graph * graph(void) const
provide main graph
Definition Module.cpp:32
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
uint64_t num_elems(const nnfw_tensorinfo *tensor_info)
Get the total number of elements in nnfw_tensorinfo->dims.

References luci::Module::graph(), num_elems(), loco::output_nodes(), and THROW_UNLESS.


The documentation for this class was generated from the following files: