ONE - On-device Neural Engine
Loading...
Searching...
No Matches
record_minmax::RandomIterator Class Reference (final)

#include <RandomIterator.h>

Collaboration diagram for record_minmax::RandomIterator:

Public Member Functions

 RandomIterator (luci::Module *module)
 
bool hasNext () const override
 
std::vector< DataBuffer > next () override
 
bool check_type_shape () const override
 
- Public Member Functions inherited from record_minmax::DataSetIterator
virtual ~DataSetIterator ()=default
 

Detailed Description

Definition at line 32 of file RandomIterator.h.

Constructor & Destructor Documentation

◆ RandomIterator()

record_minmax::RandomIterator::RandomIterator ( luci::Module * module)

Definition at line 64 of file RandomIterator.cpp.

65{
66 assert(module); // FIX_CALLER_UNLESS
67
68 std::random_device rd;
69 std::mt19937 _gen(rd());
70
71 auto input_nodes = loco::input_nodes(module->graph());
72 for (auto input_node : input_nodes)
73 {
74 const auto cnode = loco::must_cast<const luci::CircleInput *>(input_node);
75 _input_nodes.emplace_back(cnode);
76 }
77
78 // Hardcoded
79 _num_data = 3;
80}
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71

References loco::input_nodes().

Member Function Documentation

◆ check_type_shape()

bool record_minmax::RandomIterator::check_type_shape ( ) const
override virtual

Implements record_minmax::DataSetIterator.

Definition at line 143 of file RandomIterator.cpp.

143{ return false; }

◆ hasNext()

bool record_minmax::RandomIterator::hasNext ( ) const
override virtual

Implements record_minmax::DataSetIterator.

Definition at line 82 of file RandomIterator.cpp.

82{ return _curr_idx < _num_data; }

◆ next()

std::vector< DataBuffer > record_minmax::RandomIterator::next ( )
override virtual

Implements record_minmax::DataSetIterator.

Definition at line 84 of file RandomIterator.cpp.

85{
86 std::vector<DataBuffer> res;
87
88 for (auto input_node : _input_nodes)
89 {
91
92 const auto dtype = input_node->dtype();
93 const auto num_elements = numElements(input_node);
94
95 buf.data.resize(getTensorSize(input_node));
96
97 switch (dtype)
98 {
99 case loco::DataType::FLOAT32:
100 {
101 const auto input_data = genRandomData(_gen, num_elements, -5, 5);
102 const auto data_size = input_data.size() * sizeof(float);
103 assert(buf.data.size() == data_size);
104 memcpy(buf.data.data(), input_data.data(), data_size);
105 break;
106 }
107 case loco::DataType::S32:
108 {
109 const auto input_data = genRandomIntData<int32_t>(_gen, num_elements, 0, 100);
110 const auto data_size = input_data.size() * sizeof(int32_t);
111 assert(buf.data.size() == data_size);
112 memcpy(buf.data.data(), input_data.data(), data_size);
113 break;
114 }
115 case loco::DataType::S64:
116 {
117 const auto input_data = genRandomIntData<int64_t>(_gen, num_elements, 0, 100);
118 const auto data_size = input_data.size() * sizeof(int64_t);
119 assert(buf.data.size() == data_size);
120 memcpy(buf.data.data(), input_data.data(), data_size);
121 break;
122 }
123 case loco::DataType::BOOL:
124 {
125 const auto input_data = genRandomIntData<uint8_t>(_gen, num_elements, 0, 1);
126 const auto data_size = input_data.size() * sizeof(uint8_t);
127 assert(buf.data.size() == data_size);
128 memcpy(buf.data.data(), input_data.data(), data_size);
129 break;
130 }
131 default:
132 throw std::runtime_error("Unsupported datatype");
133 }
134
135 res.emplace_back(buf);
136 }
137
138 _curr_idx++; // move to the next index
139
140 return res;
141}
list input_data
Definition infer.py:29
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59
size_t getTensorSize(const luci::CircleNode *node)
Definition Utils.cpp:57
uint32_t numElements(const luci::CircleNode *node)
Definition Utils.cpp:41
std::vector< char > DataBuffer

References record_minmax::getTensorSize(), and record_minmax::numElements().


The documentation for this class was generated from the following files: