import os
from typing import Any, Iterator, List, Optional, Tuple, Union

import numpy as np
class DataLoader:
    """
    A flexible DataLoader to manage training and validation data.
    Automatically detects whether inputs are paths or NumPy arrays.
    """

    def __init__(self,
                 input_dataset: Union[List[np.ndarray], np.ndarray, str],
                 expected_dataset: Union[List[np.ndarray], np.ndarray, str],
                 batch_size: int,
                 input_shape: Optional[Tuple[int, ...]] = None,
                 expected_shape: Optional[Tuple[int, ...]] = None,
                 dtype: Any = np.float32) -> None:
        """
        Initialize the DataLoader.

        Args:
            input_dataset (list of np.ndarray | np.ndarray | str):
                List of input arrays where each array's first dimension is the batch dimension,
                or a single NumPy array, or a file path.
            expected_dataset (list of np.ndarray | np.ndarray | str):
                List of expected arrays where each array's first dimension is the batch dimension,
                or a single NumPy array, or a file path.
            batch_size (int): Number of samples per batch.
            input_shape (tuple[int, ...], optional): Shape of the input data if raw format is used.
            expected_shape (tuple[int, ...], optional): Shape of the expected data if raw format is used.
            dtype (type, optional): Data type of the raw file (default: np.float32).

        Raises:
            ValueError: If the input and expected sides disagree on the number
                of samples, or if a dataset has an unsupported form/format.
        """
        self.batch_size: int = batch_size
        self.inputs: List[np.ndarray] = self._process_dataset(
            input_dataset, input_shape, dtype)
        self.expecteds: List[np.ndarray] = self._process_dataset(
            expected_dataset, expected_shape, dtype)
        # Every array on a side is expected to share the same first (batch)
        # dimension, so the first array of each side is representative.
        if self.inputs[0].shape[0] != self.expecteds[0].shape[0]:
            raise ValueError(
                "Input data and expected data must have the same number of samples.")

    def _process_dataset(self,
                         data: Union[List[np.ndarray], np.ndarray, str],
                         shape: Optional[Tuple[int, ...]],
                         dtype: Any = np.float32) -> List[np.ndarray]:
        """
        Process a dataset or file path.

        Args:
            data (str | np.ndarray | list[np.ndarray]): Path to file or NumPy arrays.
            shape (tuple[int, ...], optional): Shape of the data if raw format is used.
            dtype (type, optional): Data type for raw files.

        Returns:
            list[np.ndarray]: Loaded or passed data as NumPy arrays.

        Raises:
            ValueError: If the dataset is not a list of NumPy arrays, a NumPy
                array, or a valid file path.
        """
        if isinstance(data, list):
            # A list is accepted as-is, but only if every element is an array.
            if all(isinstance(item, np.ndarray) for item in data):
                return data
            raise ValueError(
                "All elements in the list must be NumPy arrays.")
        if isinstance(data, np.ndarray):
            # NOTE(review): a single array is split along its first axis into a
            # list of sub-arrays (e.g. a stacked (num_arrays, N, ...) tensor)
            # — confirm this matches callers' expectations.
            return [data[i] for i in range(data.shape[0])]
        elif isinstance(data, str):
            # Paths are loaded from disk and then split like a plain array.
            loaded = self._load_data(data, shape, dtype)
            return [loaded[i] for i in range(loaded.shape[0])]
        raise ValueError(
            "Data must be a NumPy array or a valid file path.")

    def _load_data(self,
                   file_path: str,
                   shape: Optional[Tuple[int, ...]],
                   dtype: Any = np.float32) -> np.ndarray:
        """
        Load data from a file, supporting both .npy and raw formats.

        Args:
            file_path (str): Path to the file to load.
            shape (tuple[int, ...], optional): Shape of the data if raw format is used.
            dtype (type, optional): Data type of the raw file (default: np.float32).

        Returns:
            np.ndarray: Loaded data as a NumPy array.

        Raises:
            ValueError: If the extension is unsupported, or if a raw file is
                given without an explicit shape.
        """
        _, ext = os.path.splitext(file_path)
        if ext == ".npy":
            # .npy files carry their own shape and dtype metadata.
            return np.load(file_path)
        elif ext in [".bin", ".raw"]:
            # Raw binary carries no metadata, so the caller must supply a shape.
            if shape is None:
                raise ValueError(f"Shape must be provided for raw file: {file_path}")
            return self._load_raw(file_path, shape, dtype)
        raise ValueError(f"Unsupported file format: {ext}")

    def _load_raw(self, file_path: str, shape: Tuple[int, ...], dtype: Any) -> np.ndarray:
        """
        Load raw binary data.

        Args:
            file_path (str): Path to the raw binary file.
            shape (tuple[int, ...]): Shape of the data to reshape into.
            dtype (type): Data type of the binary file.

        Returns:
            np.ndarray: Loaded data as a NumPy array.

        Raises:
            ValueError: If the file size does not match the expected size
                implied by ``shape`` and ``dtype``.
        """
        expected_elements: int = int(np.prod(shape))
        expected_size: int = expected_elements * np.dtype(dtype).itemsize
        actual_size: int = os.path.getsize(file_path)
        # Validate cheaply against the on-disk size before reading anything.
        if actual_size != expected_size:
            raise ValueError(
                f"Raw file size ({actual_size} bytes) does not match the expected size "
                f"({expected_size} bytes) based on the provided shape {shape} and dtype {dtype}.")
        with open(file_path, "rb") as f:
            data = f.read()
        array = np.frombuffer(data, dtype=dtype)
        # Defensive re-check after decoding into elements of ``dtype``.
        if array.size != expected_elements:
            raise ValueError(
                f"Raw data size does not match the expected shape: {shape}. "
                f"Expected {expected_elements} elements, got {array.size} elements.")
        return array.reshape(shape)

    def _create_batches(self) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        """
        Precompute batches for inputs and expected outputs.

        Returns:
            tuple: Lists of batched inputs and batched expecteds.
        """
        batched_inputs: List[List[np.ndarray]] = []
        batched_expecteds: List[List[np.ndarray]] = []
        num_samples = self.inputs[0].shape[0]
        for batch_start in range(0, num_samples, self.batch_size):
            batch_end = min(batch_start + self.batch_size, num_samples)
            inputs_batch = [
                input_array[batch_start:batch_end]
                for input_array in self.inputs
            ]
            # np.resize pads a short trailing batch by repeating its samples so
            # every batch has exactly ``batch_size`` rows.
            inputs_batch = [
                np.resize(batch, (self.batch_size, *batch.shape[1:]))
                for batch in inputs_batch
            ]
            batched_inputs.append(inputs_batch)
            expecteds_batch = [
                expected_array[batch_start:batch_end]
                for expected_array in self.expecteds
            ]
            expecteds_batch = [
                np.resize(batch, (self.batch_size, *batch.shape[1:]))
                for batch in expecteds_batch
            ]
            batched_expecteds.append(expecteds_batch)
        return batched_inputs, batched_expecteds

    def __iter__(self) -> Iterator[Tuple[List[np.ndarray], List[np.ndarray]]]:
        """
        Make the DataLoader iterable.

        Returns:
            Iterator over (inputs, expecteds) batches.
        """
        # Batches are (re)computed at the start of every iteration pass.
        self._batched_inputs, self._batched_expecteds = self._create_batches()
        self._batch_index = 0
        return self

    def __next__(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Return the next batch of data.

        Returns:
            tuple: (inputs, expecteds) for the next batch.

        Raises:
            StopIteration: When all precomputed batches have been consumed.
        """
        if self._batch_index >= len(self._batched_inputs):
            raise StopIteration
        input_batch = self._batched_inputs[self._batch_index]
        expected_batch = self._batched_expecteds[self._batch_index]
        self._batch_index += 1
        return input_batch, expected_batch

    def split(self, validation_split: float) -> Tuple["DataLoader", "DataLoader"]:
        """
        Split the data into training and validation sets.

        Args:
            validation_split (float): Ratio of validation data. Must be between 0.0 and 1.0.

        Returns:
            tuple: Two DataLoader instances, one for training and one for validation.

        Raises:
            ValueError: If ``validation_split`` is outside [0.0, 1.0].
        """
        if not (0.0 <= validation_split <= 1.0):
            raise ValueError(
                "Validation split must be between 0.0 and 1.0.")
        # Leading samples go to training, trailing samples to validation.
        split_index = int(len(self.inputs[0]) * (1.0 - validation_split))
        train_inputs = [input_array[:split_index]
                        for input_array in self.inputs]
        val_inputs = [input_array[split_index:]
                      for input_array in self.inputs]
        train_expecteds = [
            expected_array[:split_index]
            for expected_array in self.expecteds
        ]
        val_expecteds = [
            expected_array[split_index:]
            for expected_array in self.expecteds
        ]
        train_loader = DataLoader(train_inputs, train_expecteds, self.batch_size)
        val_loader = DataLoader(val_inputs, val_expecteds, self.batch_size)
        return train_loader, val_loader
# --- Signature index (auto-generated summary; kept for reference) ---
# List[np.ndarray] _process_dataset(self, Union[List[np.ndarray], np.ndarray, str] data, Optional[Tuple[int,...]] shape, Any dtype=np.float32)
# Tuple[List[List[np.ndarray]], List[List[np.ndarray]]] _create_batches(self)
# np.ndarray _load_raw(self, str file_path, Tuple[int,...] shape, Any dtype)
# Tuple[List[np.ndarray], List[np.ndarray]] __next__(self)
# Iterator[Tuple[List[np.ndarray], List[np.ndarray]]] __iter__(self)
# None __init__(self, Union[List[np.ndarray], np.ndarray, str] input_dataset, Union[List[np.ndarray], np.ndarray, str] expected_dataset, int batch_size, Optional[Tuple[int,...]] input_shape=None, Optional[Tuple[int,...]] expected_shape=None, Any dtype=np.float32)
# np.ndarray _load_data(self, str file_path, Optional[Tuple[int,...]] shape, Any dtype=np.float32)
# Tuple["DataLoader", "DataLoader"] split(self, float validation_split)