ONE - On-device Neural Engine
nnfw::misc::tensor::Comparator Class Reference

Class to compare two tensors: an expected tensor and an obtained tensor.

#include <Comparator.h>

Data Structures

struct  Observer
 Struct to observe comparison results.
 

Public Member Functions

 Comparator (const std::function< bool(float lhs, float rhs)> &fn)
 Construct a new Comparator object.
 
std::vector< Diff< float > > compare (const Shape &shape, const Reader< float > &expected, const Reader< float > &obtained, Observer *observer=nullptr) const
 Compare two tensors.
 

Detailed Description

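Class to compare two tensors, an expected one and an obtained one, element by element. Each pair of values at the same index is passed to a user-supplied predicate, and every pair the predicate rejects is reported as a difference.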

Definition at line 45 of file Comparator.h.

Constructor & Destructor Documentation

◆ Comparator()

nnfw::misc::tensor::Comparator::Comparator(const std::function<bool(float lhs, float rhs)> &fn)
inline

Construct a new Comparator object.

Parameters
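[in] fn  Function that compares two float values. compare() calls it as fn(expected, obtained); it must return true when the pair is considered a match and false when the pair should be recorded as a difference.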

Definition at line 52 of file Comparator.h.

```cpp
Comparator(const std::function<bool(float lhs, float rhs)> &fn) : _compare_fn{fn}
{
  // DO NOTHING
}
```
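The predicate decides what counts as a match. Exact equality is often too strict for floating-point results, so a tolerance-based predicate is a common choice. The sketch below assumes a combined absolute/relative tolerance; `make_tolerance_fn`, `abs_tol` and `rel_tol` are hypothetical names that are not part of nnfw, and any thresholds you pass are illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <functional>

// Hypothetical helper (not part of nnfw): builds a predicate suitable for
// Comparator's constructor. It returns true when the two values are close enough:
//   |lhs - rhs| <= max(abs_tol, rel_tol * max(|lhs|, |rhs|))
std::function<bool(float, float)> make_tolerance_fn(float abs_tol, float rel_tol)
{
  return [abs_tol, rel_tol](float lhs, float rhs) {
    // Identical values (including equal infinities) always match;
    // checking this first avoids inf - inf producing NaN below.
    if (lhs == rhs)
      return true;
    // A NaN on either side is always reported as a difference.
    if (std::isnan(lhs) || std::isnan(rhs))
      return false;
    const float diff = std::fabs(lhs - rhs);
    const float scale = std::max(std::fabs(lhs), std::fabs(rhs));
    return diff <= std::max(abs_tol, rel_tol * scale);
  };
}
```

With the real class, such a predicate would be handed straight to the constructor, e.g. `nnfw::misc::tensor::Comparator comparator{make_tolerance_fn(1e-5f, 1e-3f)};`. The absolute term keeps values near zero from failing on tiny differences, while the relative term scales the allowance for large magnitudes.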

Member Function Documentation

◆ compare()

std::vector<Diff<float>> nnfw::misc::tensor::Comparator::compare(const Shape &shape,
                                                                const Reader<float> &expected,
                                                                const Reader<float> &obtained,
                                                                Observer *observer = nullptr) const

Compare two tensors.

Parameters
[in] shape     Shape shared by both tensors
[in] expected  Reader<float> object that accesses the expected tensor
[in] obtained  Reader<float> object that accesses the obtained tensor
[in] observer  Optional observer (may be nullptr) that is notified of the expected and obtained values at every index, not only at mismatches
Returns
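A std::vector<Diff<float>> with one entry (index, expected value, obtained value) for each element on which fn returned false; the vector is empty when every element matches.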
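To make the contract of compare() concrete, here is a minimal flat-buffer sketch of the same logic. It is not nnfw code: `FlatDiff`, `FlatObserver` and `compare_flat` are hypothetical, a `std::size_t` offset stands in for `Index`, `std::vector<float>` stands in for `Reader<float>`, and the observer is assumed to be polymorphic (the real `Observer` declaration is not shown on this page).

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for Diff<float>: where a mismatch happened and both values.
struct FlatDiff
{
  std::size_t index;
  float expected;
  float obtained;
};

// Hypothetical stand-in for Comparator::Observer, assumed to be an abstract interface.
struct FlatObserver
{
  virtual ~FlatObserver() = default;
  virtual void notify(std::size_t index, float expected, float obtained) = 0;
};

// Mirrors the documented behavior of Comparator::compare:
// every element is visited, the observer (if any) sees every element,
// and only elements for which fn returns false are returned.
std::vector<FlatDiff> compare_flat(const std::function<bool(float, float)> &fn,
                                   const std::vector<float> &expected,
                                   const std::vector<float> &obtained,
                                   FlatObserver *observer = nullptr)
{
  std::vector<FlatDiff> res;
  // Assumes both buffers have the same length, as a shared Shape implies.
  for (std::size_t i = 0; i < expected.size(); ++i)
  {
    if (!fn(expected[i], obtained[i]))
      res.push_back({i, expected[i], obtained[i]});
    if (observer != nullptr)
      observer->notify(i, expected[i], obtained[i]);
  }
  return res;
}
```

Keeping the observer call outside the mismatch branch is what lets an observer gather statistics over the whole tensor rather than over failures only.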

Definition at line 29 of file Comparator.cpp.

```cpp
std::vector<Diff<float>> Comparator::compare(const Shape &shape, const Reader<float> &expected,
                                             const Reader<float> &obtained,
                                             Observer *observer) const
{
  std::vector<Diff<float>> res;

  zip(shape, expected, obtained) <<
    [&](const Index &index, float expected_value, float obtained_value) {
      if (!_compare_fn(expected_value, obtained_value))
      {
        res.emplace_back(index, expected_value, obtained_value);
      }

      // Update max_diff_index, if necessary
      if (observer != nullptr)
      {
        observer->notify(index, expected_value, obtained_value);
      }
    };

  return res;
}
```
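The listing carries a comment about updating `max_diff_index` right before the observer is notified; compare() itself keeps no such state, so that kind of bookkeeping belongs in an observer. Below is a minimal sketch of one, assuming a flat `std::size_t` index in place of `Index`. `MaxDiffObserver` and its accessors are hypothetical and do not derive from nnfw's `Observer`.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical observer (not nnfw's Observer): same notify() argument order as in
// the listing, but with a flat std::size_t index instead of Index.
// It remembers where the largest absolute difference occurred.
class MaxDiffObserver
{
public:
  void notify(std::size_t index, float expected, float obtained)
  {
    const float diff = std::fabs(expected - obtained);
    // Strict '>' keeps the first index when several elements share the maximum.
    if (!_seen || diff > _max_diff)
    {
      _seen = true;
      _max_diff = diff;
      _max_diff_index = index;
    }
  }

  bool seen() const { return _seen; }
  float max_diff() const { return _max_diff; }
  std::size_t max_diff_index() const { return _max_diff_index; }

private:
  bool _seen = false;
  float _max_diff = 0.0f;
  std::size_t _max_diff_index = 0;
};
```

Because compare() notifies the observer at every index, such an observer reports the worst element even when all elements pass the predicate.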
Zipper<T> zip(const Shape &shape, const Reader<T> &lhs, const Reader<T> &rhs): returns a Zipper object constructed from the passed arguments (defined in Zipper.h, line 95).

References nnfw::misc::tensor::Comparator::Observer::notify(), and nnfw::misc::tensor::zip().
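compare() hands the traversal to `zip(shape, expected, obtained) << lambda`: zip() builds a Zipper, and applying `<<` to it invokes the lambda with an index and the two values stored at that index, for every index of the shape. The sketch below illustrates that idiom with hypothetical, simplified types (`Index2D`, `MiniZipper`, `zip2d`, a 2-D row-major layout). It is not nnfw's Zipper, whose visiting order is not documented on this page.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical 2-D index (stand-in for nnfw's Index).
struct Index2D
{
  std::size_t row;
  std::size_t col;
};

// Hypothetical, simplified analogue of Zipper<float>: holds references to two
// row-major buffers of the same {rows, cols} shape.
class MiniZipper
{
public:
  MiniZipper(std::size_t rows, std::size_t cols, const std::vector<float> &lhs,
             const std::vector<float> &rhs)
    : _rows{rows}, _cols{cols}, _lhs{lhs}, _rhs{rhs}
  {
  }

  // Visits every index in row-major order and passes both values to cb.
  template <typename Callable> void operator<<(Callable &&cb) const
  {
    for (std::size_t r = 0; r < _rows; ++r)
      for (std::size_t c = 0; c < _cols; ++c)
      {
        const std::size_t offset = r * _cols + c;
        cb(Index2D{r, c}, _lhs[offset], _rhs[offset]);
      }
  }

private:
  std::size_t _rows;
  std::size_t _cols;
  const std::vector<float> &_lhs;
  const std::vector<float> &_rhs;
};

// Hypothetical counterpart of zip(shape, lhs, rhs).
MiniZipper zip2d(std::size_t rows, std::size_t cols, const std::vector<float> &lhs,
                 const std::vector<float> &rhs)
{
  return MiniZipper{rows, cols, lhs, rhs};
}
```

The `<<` form keeps the traversal in one place while the caller supplies only the per-element action, which is how compare() can stay a few lines long.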


The documentation for this class was generated from the following files: