ONE - On-device Neural Engine
dataloader.py
import os
import numpy as np
from typing import List, Tuple, Union, Optional, Any, Iterator


7 """
8 A flexible DataLoader to manage training and validation data.
9 Automatically detects whether inputs are paths or NumPy arrays.
10 """
    def __init__(self,
                 input_dataset: Union[List[np.ndarray], np.ndarray, str],
                 expected_dataset: Union[List[np.ndarray], np.ndarray, str],
                 batch_size: int,
                 input_shape: Optional[Tuple[int, ...]] = None,
                 expected_shape: Optional[Tuple[int, ...]] = None,
                 dtype: Any = np.float32) -> None:
        """
        Initialize the DataLoader.

        Args:
            input_dataset (list of np.ndarray | np.ndarray | str):
                List of input arrays where each array's first dimension is the batch dimension,
                or a single NumPy array, or a file path.
            expected_dataset (list of np.ndarray | np.ndarray | str):
                List of expected arrays where each array's first dimension is the batch dimension,
                or a single NumPy array, or a file path.
            batch_size (int): Number of samples per batch.
            input_shape (tuple[int, ...], optional): Shape of the input data if raw format is used.
            expected_shape (tuple[int, ...], optional): Shape of the expected data if raw format is used.
            dtype (type, optional): Data type of the raw file (default: np.float32).
        """
        self.batch_size: int = batch_size
        self.inputs: List[np.ndarray] = self._process_dataset(
            input_dataset, input_shape, dtype)
        self.expecteds: List[np.ndarray] = self._process_dataset(
            expected_dataset, expected_shape, dtype)
        self.batched_inputs: List[List[np.ndarray]] = []

        # Verify data consistency
        self.num_samples: int = self.inputs[0].shape[0]  # Batch dimension
        if self.num_samples != self.expecteds[0].shape[0]:
            raise ValueError(
                "Input data and expected data must have the same number of samples.")

        # Precompute batches
        self.batched_inputs, self.batched_expecteds = self._create_batches()

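    # A minimal construction sketch (illustrative only; the arrays and sizes
    # below are made up for demonstration):
    #
    #   x = np.random.rand(100, 28, 28).astype(np.float32)   # 100 samples
    #   y = np.random.rand(100, 10).astype(np.float32)
    #   loader = DataLoader([x], [y], batch_size=32)          # 4 batches
    #                                                         # (last one padded)
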
    def _process_dataset(self,
                         data: Union[List[np.ndarray], np.ndarray, str],
                         shape: Optional[Tuple[int, ...]],
                         dtype: Any = np.float32) -> List[np.ndarray]:
        """
        Process a dataset or file path.

        Args:
            data (str | np.ndarray | list[np.ndarray]): Path to file or NumPy arrays.
            shape (tuple[int, ...], optional): Shape of the data if raw format is used.
            dtype (type, optional): Data type for raw files.

        Returns:
            list[np.ndarray]: Loaded or passed data as NumPy arrays.
        """
        if isinstance(data, list):
            # Check that all elements in the list are NumPy arrays
            if all(isinstance(item, np.ndarray) for item in data):
                return data
            raise ValueError("All elements in the list must be NumPy arrays.")
        if isinstance(data, np.ndarray):
            if data.ndim > 1:
                # Split a multi-dimensional array along its first axis into a
                # list of arrays, one per slice
                return [data[i] for i in range(data.shape[0])]
            else:
                # Wrap a single array into a list
                return [data]
        elif isinstance(data, str):
            # A string is assumed to be a file path
            return [self._load_data(data, shape, dtype)]
        else:
            raise ValueError("Data must be a NumPy array or a valid file path.")

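    # For reference, the three accepted forms handled above (our reading of
    # the code; the names are hypothetical):
    #
    #   loader._process_dataset([x0, x1], None)    # list -> returned as-is
    #   loader._process_dataset(stacked, None)     # ndim > 1 -> split along
    #                                              #   axis 0, one array per slice
    #   loader._process_dataset("x.npy", None)     # str -> loaded from file
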
    def _load_data(self,
                   file_path: str,
                   shape: Optional[Tuple[int, ...]],
                   dtype: Any = np.float32) -> np.ndarray:
        """
        Load data from a file, supporting both .npy and raw formats.

        Args:
            file_path (str): Path to the file to load.
            shape (tuple[int, ...], optional): Shape of the data if raw format is used.
            dtype (type, optional): Data type of the raw file (default: np.float32).

        Returns:
            np.ndarray: Loaded data as a NumPy array.
        """
        _, ext = os.path.splitext(file_path)

        if ext == ".npy":
            # Load .npy file; shape and dtype are stored in the file itself
            return np.load(file_path)
        elif ext in [".bin", ".raw"]:
            # Load raw binary file; raw data carries no metadata
            if shape is None:
                raise ValueError(f"Shape must be provided for raw file: {file_path}")
            return self._load_raw(file_path, shape, dtype)
        else:
            raise ValueError(f"Unsupported file format: {ext}")

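    # Dispatch sketch: a .npy file carries its own shape/dtype, while .bin
    # and .raw need an explicit shape. The file names here are made up:
    #
    #   np.save("input.npy", np.zeros((10, 4), dtype=np.float32))
    #   DataLoader("input.npy", "input.npy", batch_size=2)    # ok
    #   DataLoader("input.raw", "input.raw", batch_size=2)    # ValueError:
    #                                                         #   shape missing
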
    def _load_raw(self, file_path: str, shape: Tuple[int, ...], dtype: Any) -> np.ndarray:
        """
        Load raw binary data.

        Args:
            file_path (str): Path to the raw binary file.
            shape (tuple[int, ...]): Shape of the data to reshape into.
            dtype (type): Data type of the binary file.

        Returns:
            np.ndarray: Loaded data as a NumPy array.
        """
        # Calculate the expected number of elements based on the provided shape
        expected_elements: int = int(np.prod(shape))

        # Calculate the expected size of the raw file in bytes
        expected_size: int = expected_elements * np.dtype(dtype).itemsize

        # Get the actual size of the raw file
        actual_size: int = os.path.getsize(file_path)

        # Check if the sizes match
        if actual_size != expected_size:
            raise ValueError(
                f"Raw file size ({actual_size} bytes) does not match the expected size "
                f"({expected_size} bytes) based on the provided shape {shape} and dtype {dtype}."
            )

        # Read and load the raw data
        with open(file_path, "rb") as f:
            data = f.read()
        array = np.frombuffer(data, dtype=dtype)
        if array.size != expected_elements:
            raise ValueError(
                f"Raw data size does not match the expected shape: {shape}. "
                f"Expected {expected_elements} elements, got {array.size} elements.")
        return array.reshape(shape)

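    # Worked size check for a hypothetical file: shape (8, 3) and np.float32
    # give 8 * 3 = 24 elements * 4 bytes = 96 bytes, so any other on-disk
    # size is rejected before parsing:
    #
    #   np.zeros((8, 3), dtype=np.float32).tofile("data.raw")  # 96 bytes
    #   DataLoader("data.raw", "data.raw", batch_size=2,
    #              input_shape=(8, 3), expected_shape=(8, 3))   # ok
    #   # ... with expected_shape=(8, 4) instead -> ValueError (128 != 96)
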
    def _create_batches(self) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        """
        Precompute batches for inputs and expected outputs.

        Returns:
            tuple: Lists of batched inputs and batched expecteds.
        """
        batched_inputs: List[List[np.ndarray]] = []
        batched_expecteds: List[List[np.ndarray]] = []

        for batch_start in range(0, self.num_samples, self.batch_size):
            batch_end = min(batch_start + self.batch_size, self.num_samples)

            # Collect batched inputs
            inputs_batch = [
                input_array[batch_start:batch_end] for input_array in self.inputs
            ]
            if batch_end - batch_start < self.batch_size:
                # Resize the last batch to match batch_size
                inputs_batch = [
                    np.resize(batch, (self.batch_size, *batch.shape[1:]))
                    for batch in inputs_batch
                ]

            batched_inputs.append(inputs_batch)

            # Collect batched expecteds
            expecteds_batch = [
                expected_array[batch_start:batch_end] for expected_array in self.expecteds
            ]
            if batch_end - batch_start < self.batch_size:
                # Resize the last batch to match batch_size
                expecteds_batch = [
                    np.resize(batch, (self.batch_size, *batch.shape[1:]))
                    for batch in expecteds_batch
                ]

            batched_expecteds.append(expecteds_batch)

        return batched_inputs, batched_expecteds

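    # Note on the padding above: np.resize cycles the data already in the
    # slice, so a short final batch is filled by repeating its own samples
    # (not samples from the full dataset). With num_samples = 5 and
    # batch_size = 4, the last batch is sample 4 repeated four times:
    #
    #   np.resize(np.array([[5.0]]), (4, 1))   # -> [[5.], [5.], [5.], [5.]]
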
    def __iter__(self) -> Iterator[Tuple[List[np.ndarray], List[np.ndarray]]]:
        """
        Make the DataLoader iterable.

        Returns:
            self
        """
        self.index = 0
        return self

    def __next__(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Return the next batch of data.

        Returns:
            tuple: (inputs, expecteds) for the next batch.
        """
        if self.index >= len(self.batched_inputs):
            raise StopIteration

        # Retrieve precomputed batch
        input_batch = self.batched_inputs[self.index]
        expected_batch = self.batched_expecteds[self.index]

        self.index += 1
        return input_batch, expected_batch

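    # Iteration sketch ("loader" is a hypothetical instance): __iter__ resets
    # the cursor, so the same loader can be replayed every epoch:
    #
    #   it = iter(loader)                    # index reset to 0
    #   first = next(it)                     # ([inputs...], [expecteds...])
    #   for inputs, expecteds in loader:     # fresh pass from the start
    #       pass
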
    def split(self, validation_split: float) -> Tuple["DataLoader", "DataLoader"]:
        """
        Split the data into training and validation sets.

        Args:
            validation_split (float): Ratio of validation data. Must be between 0.0 and 1.0.

        Returns:
            tuple: Two DataLoader instances, one for training and one for validation.
        """
        if not (0.0 <= validation_split <= 1.0):
            raise ValueError("Validation split must be between 0.0 and 1.0.")

        split_index = int(len(self.inputs[0]) * (1.0 - validation_split))

        train_inputs = [input_array[:split_index] for input_array in self.inputs]
        val_inputs = [input_array[split_index:] for input_array in self.inputs]
        train_expecteds = [
            expected_array[:split_index] for expected_array in self.expecteds
        ]
        val_expecteds = [
            expected_array[split_index:] for expected_array in self.expecteds
        ]

        train_loader = DataLoader(train_inputs, train_expecteds, self.batch_size)
        val_loader = DataLoader(val_inputs, val_expecteds, self.batch_size)

        return train_loader, val_loader
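

# Minimal end-to-end demo of the class above. This block is illustrative
# only (not part of the upstream module); all data here is synthetic.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((100, 28, 28)).astype(np.float32)
    y = rng.standard_normal((100, 10)).astype(np.float32)

    loader = DataLoader([x], [y], batch_size=32)
    train_loader, val_loader = loader.split(validation_split=0.2)

    for inputs, expecteds in train_loader:
        # One list entry per model input/output; every batch, including the
        # padded last one, has exactly batch_size samples
        assert inputs[0].shape == (32, 28, 28)
        assert expecteds[0].shape == (32, 10)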