ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mpqsolver::core::H5FileDataProvider Class Reference final

#include <DataProvider.h>

Collaboration diagram for mpqsolver::core::H5FileDataProvider:

Public Member Functions

 H5FileDataProvider (const std::string &h5file, const std::string &module_path)
 
size_t numSamples () const override
 
uint32_t numInputs (uint32_t sample) const override
 
void getSampleInput (uint32_t sample, uint32_t input, InputData &data) const override
 
- Public Member Functions inherited from mpqsolver::core::DataProvider
virtual ~DataProvider ()=default
 

Detailed Description

Definition at line 53 of file DataProvider.h.

Constructor & Destructor Documentation

◆ H5FileDataProvider()

H5FileDataProvider::H5FileDataProvider ( const std::string &  h5file,
const std::string &  module_path 
)

Definition at line 52 of file DataProvider.cpp.

53 : _importer(h5file)
54{
55 _importer.importGroup("value");
56 _is_raw_data = _importer.isRawData();
57
58 luci::ImporterEx importerex;
59 _module = importerex.importVerifyModule(module_path);
60 if (_module.get() != nullptr)
61 {
62 _input_nodes = loco::input_nodes(_module.get()->graph());
63 }
64}
void importGroup(const std::string &group)
std::unique_ptr< Module > importVerifyModule(const std::string &input_path) const
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71

References dio::hdf5::HDF5Importer::importGroup(), luci::ImporterEx::importVerifyModule(), loco::input_nodes(), and dio::hdf5::HDF5Importer::isRawData().

Member Function Documentation

◆ getSampleInput()

void H5FileDataProvider::getSampleInput ( uint32_t  sample,
uint32_t  input,
InputData &  data 
) const
override virtual

Implements mpqsolver::core::DataProvider.

Definition at line 73 of file DataProvider.cpp.

74{
75 if (_is_raw_data)
76 {
77 _importer.readTensor(sample, input, data.data().data(), data.data().size());
78 }
79 else
80 {
81 DataType dtype;
82 Shape shape;
83 _importer.readTensor(sample, input, &dtype, &shape, data.data().data(), data.data().size());
84
85 // Check the type and the shape of the input data is valid
86 auto input_node = loco::must_cast<luci::CircleNode *>(_input_nodes.at(input));
87 verifyTypeShape(input_node, dtype, shape);
88 }
89}
void readTensor(int32_t data_idx, int32_t input_idx, loco::DataType *dtype, std::vector< loco::Dimension > *shape, void *buffer, size_t buffer_bytes) const
Read tensor data from file and store it into buffer.
void verifyTypeShape(const luci::CircleInput *input_node, const DataType &dtype, const Shape &shape)
DataType
"scalar" value type
Definition DataType.h:27
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.
Definition Shape.h:28

References dio::hdf5::HDF5Importer::readTensor().

◆ numInputs()

uint32_t H5FileDataProvider::numInputs ( uint32_t  sample) const
override virtual

Implements mpqsolver::core::DataProvider.

Definition at line 68 of file DataProvider.cpp.

69{
70 return static_cast<uint32_t>(_importer.numInputs(sample));
71}
int32_t numInputs(int32_t data_idx) const

References dio::hdf5::HDF5Importer::numInputs().

◆ numSamples()

size_t H5FileDataProvider::numSamples ( ) const
override virtual

Implements mpqsolver::core::DataProvider.

Definition at line 66 of file DataProvider.cpp.

66{ return _importer.numData(); }
int32_t numData() const

References dio::hdf5::HDF5Importer::numData().


The documentation for this class was generated from the following files: