ONE - On-device Neural Engine
Loading...
Searching...
No Matches
record_minmax::HDF5Iterator Class Reference [final]

#include <HDF5Iterator.h>

Collaboration diagram for record_minmax::HDF5Iterator:

Public Member Functions

 HDF5Iterator (const std::string &file_path, luci::Module *module)
 
bool hasNext () const override
 
std::vector< DataBuffer > next () override
 
bool check_type_shape () const override
 
- Public Member Functions inherited from record_minmax::DataSetIterator
virtual ~DataSetIterator ()=default
 

Detailed Description

Definition at line 33 of file HDF5Iterator.h.

Constructor & Destructor Documentation

◆ HDF5Iterator()

record_minmax::HDF5Iterator::HDF5Iterator ( const std::string &  file_path,
luci::Module * module 
)

Definition at line 29 of file HDF5Iterator.cpp.

30 : _importer(file_path)
31{
32 try
33 {
34 _importer.importGroup("value");
35
36 _is_raw_data = _importer.isRawData();
37
38 _num_data = _importer.numData();
39 }
40 catch (const H5::Exception &e)
41 {
42 H5::Exception::printErrorStack();
43 throw std::runtime_error("HDF5 error occurred during initialization.");
44 }
45
46 auto input_nodes = loco::input_nodes(module->graph());
47 for (auto input_node : input_nodes)
48 {
49 const auto cnode = loco::must_cast<const luci::CircleInput *>(input_node);
50 _input_nodes.emplace_back(cnode);
51 }
52}
int32_t numData() const
void importGroup(const std::string &group)
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71

References dio::hdf5::HDF5Importer::importGroup(), loco::input_nodes(), dio::hdf5::HDF5Importer::isRawData(), and dio::hdf5::HDF5Importer::numData().

Member Function Documentation

◆ check_type_shape()

bool record_minmax::HDF5Iterator::check_type_shape ( ) const
[override], [virtual]

Implements record_minmax::DataSetIterator.

Definition at line 94 of file HDF5Iterator.cpp.

95{
96 // If it's raw data, we don't need to check type and shape
97 return not _is_raw_data;
98}

Referenced by next().

◆ hasNext()

bool record_minmax::HDF5Iterator::hasNext ( ) const
[override], [virtual]

Implements record_minmax::DataSetIterator.

Definition at line 54 of file HDF5Iterator.cpp.

54{ return _curr_idx < _num_data; }

◆ next()

std::vector< DataBuffer > record_minmax::HDF5Iterator::next ( )
[override], [virtual]

Implements record_minmax::DataSetIterator.

Definition at line 56 of file HDF5Iterator.cpp.

57{
58 std::vector<DataBuffer> res;
59
60 try
61 {
62 for (int32_t input_idx = 0; input_idx < _importer.numInputs(_curr_idx); input_idx++)
63 {
64 DataBuffer buf;
65
66 const auto input_node = _input_nodes.at(input_idx);
67 const auto input_size = getTensorSize(input_node);
68 buf.data.resize(input_size);
69
70 if (check_type_shape())
71 {
72 _importer.readTensor(_curr_idx, input_idx, &buf.dtype, &buf.shape, buf.data.data(),
73 input_size);
74 }
75 else
76 {
77 _importer.readTensor(_curr_idx, input_idx, buf.data.data(), input_size);
78 }
79
80 res.emplace_back(buf);
81 }
82 }
83 catch (const H5::Exception &e)
84 {
85 H5::Exception::printErrorStack();
86 throw std::runtime_error("HDF5 error occurred during iteration.");
87 }
88
89 _curr_idx++; // move to the next index
90
91 return res;
92}
int32_t numInputs(int32_t data_idx) const
void readTensor(int32_t data_idx, int32_t input_idx, loco::DataType *dtype, std::vector< loco::Dimension > *shape, void *buffer, size_t buffer_bytes) const
Read tensor data from file and store it into buffer.
bool check_type_shape() const override
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.
size_t getTensorSize(const luci::CircleNode *node)
Definition Utils.cpp:57
std::vector< char > DataBuffer

References check_type_shape(), record_minmax::getTensorSize(), dio::hdf5::HDF5Importer::numInputs(), and dio::hdf5::HDF5Importer::readTensor().


The documentation for this class was generated from the following files:
- HDF5Iterator.h
- HDF5Iterator.cpp