    def __init__(self, nnpackage_path: str, backends: str = "train") -> None:
        """
        Initialize the train session.

        Args:
            nnpackage_path (str): Path to the nnpackage file or directory.
            backends (str): Backends to use, default is "train".
        """
        load_start: float = time.perf_counter()
        super().__init__(
            libnnfw_api_pybind.experimental.nnfw_session(nnpackage_path, backends))
        load_end: float = time.perf_counter()

        self.total_time: Dict[str, Union[float, List[float]]] = {
            'MODEL_LOAD': (load_end - load_start) * 1000
        }
        self.train_info: traininfo = self.session.train_get_traininfo()
        self.optimizer: Optional[Optimizer] = None
        self.loss: Optional[LossFunction] = None
        self.metrics: List[Metric] = []
    def compile(self,
                optimizer: Union[str, Optimizer],
                loss: Union[str, LossFunction],
                metrics: List[Union[str, Metric]] = [],
                batch_size: int = 16) -> None:
        """
        Compile the session with optimizer, loss, and metrics.

        Args:
            optimizer (str or Optimizer): Optimizer instance or name.
            loss (str or LossFunction): Loss instance or name.
            metrics (list of str or Metric): Metrics to evaluate during training.
            batch_size (int): Number of samples per batch.
        """
        # Resolve string names into registered optimizer/loss/metric objects.
        self.optimizer = (OptimizerRegistry.create_optimizer(optimizer)
                          if isinstance(optimizer, str) else optimizer)
        self.loss = (LossRegistry.create_loss(loss)
                     if isinstance(loss, str) else loss)
        self.metrics = [
            MetricsRegistry.create_metric(m) if isinstance(m, str) else m
            for m in metrics
        ]
        for m in self.metrics:
            if not isinstance(m, Metric):
                raise TypeError(f"Invalid metric type: {type(m).__name__}. "
                                "All metrics must inherit from the Metric base class.")
        # Require one metric per model output (or no metrics at all).
        num_outputs: int = self.session.output_size()
        if 0 < len(self.metrics) != num_outputs:
            raise ValueError(
                f"Number of metrics ({len(self.metrics)}) does not match outputs ({num_outputs}). "
                "Please ensure one metric is provided for each model output.")
        self.train_info.loss_info.loss = LossRegistry.map_loss_function_to_enum(self.loss)
        # Prepare the session for training and record compile time.
        compile_start: float = time.perf_counter()
        self.session.train_prepare()
        compile_end: float = time.perf_counter()
        self.total_time["COMPILE"] = (compile_end - compile_start) * 1000
    def train(self,
              data_loader: DataLoader,
              epochs: int,
              validation_split: float = 0.0,
              checkpoint_path: Optional[str] = None
              ) -> Dict[str, Union[float, List[float]]]:
        """
        Train the model using the given data loader.

        Args:
            data_loader: Data loader providing input and expected data.
            epochs (int): Number of epochs to train.
            validation_split (float): Ratio of validation data. Default is 0.0.
            checkpoint_path (str, optional): Path to save training checkpoints.

        Returns:
            dict: Timing and performance metrics.
        """
        if self.optimizer is None or self.loss is None:
            raise RuntimeError("Call compile() before train().")

        train_data, val_data = data_loader.split(validation_split)
        epoch_times: List[float] = []
        for epoch in range(epochs):
            epoch_start = time.perf_counter()
            train_loss, io_ms, train_ms = self._run_phase(train_data, train=True)
            msg: List[str] = [
                f"Epoch {epoch+1}/{epochs}", f"Train time: {train_ms:.3f}ms/step",
                f"IO time: {io_ms:.3f}ms/step", f"Train Loss: {train_loss:.4f}"
            ]
            if validation_split > 0.0:
                val_loss, _, _ = self._run_phase(val_data, train=False)
                msg.append(f"Validation Loss: {val_loss:.4f}")

            for m in self.metrics:
                msg.append(f"{m.__class__.__name__}: {m.result():.4f}")
            epoch_times.append((time.perf_counter() - epoch_start) * 1000)
            print(" - ".join(msg))
        if checkpoint_path is not None:
            self.session.train_export_checkpoint(checkpoint_path)
        self.total_time["EXECUTE"] = sum(epoch_times)
        self.total_time["EPOCH_TIMES"] = epoch_times
        return self.total_time
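    # Illustrative usage of the whole-dataset path above (a sketch, not part of
    # the original file). It assumes this class is exposed as `TrainSession` and
    # that OptimizerRegistry/LossRegistry accept names like "adam" and "mse";
    # those names, the nnpackage path, and `data_loader` are placeholders.
    #
    #   sess = TrainSession("path/to/nnpackage", backends="train")
    #   sess.compile(optimizer="adam", loss="mse", metrics=[], batch_size=16)
    #   timings = sess.train(data_loader, epochs=5, validation_split=0.1,
    #                        checkpoint_path="checkpoint.bin")
    #   print(timings["EXECUTE"], timings["EPOCH_TIMES"])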
    def _run_phase(self,
                   data: Tuple[List[np.ndarray], List[np.ndarray]],
                   train: bool = True) -> Tuple[float, float, float]:
        """
        Run a training or validation phase.

        Args:
            data: Data generator.
            train (bool): Whether to update weights.

        Returns:
            (avg_loss, avg_io_ms, avg_train_ms)
        """
        total_loss: float = 0.0
        num_batches: int = 0
        io_time: float = 0.0
        train_time: float = 0.0

        for inputs, expecteds in data:
            # Bind input buffers.
            io_start = time.perf_counter()
            for i, input_data in enumerate(inputs):
                self.session.train_set_input(i, input_data)
            # Bind expected outputs and pre-allocate output buffers.
            outputs: List[np.ndarray] = []
            for i, expected_data in enumerate(expecteds):
                expected = np.array(expected_data,
                                    dtype=self.session.output_tensorinfo(i).dtype)
                self.session.train_set_expected(i, expected)
                output = np.zeros(expected.shape,
                                  dtype=self.session.output_tensorinfo(i).dtype)
                self.session.train_set_output(i, output)
                outputs.append(output)
            io_end = time.perf_counter()
            t_start = time.perf_counter()
            self.session.train(update_weights=train)
            t_end = time.perf_counter()
            # Accumulate per-batch loss, metric state, and timing.
            batch_loss = sum(
                self.session.train_get_loss(i) for i in range(len(expecteds)))
            total_loss += batch_loss

            for m in self.metrics:
                m.update_state(outputs, expecteds)

            io_time += (io_end - io_start)
            train_time += (t_end - t_start)
            num_batches += 1
        if num_batches > 0:
            return (total_loss / num_batches, (io_time * 1000) / num_batches,
                    (train_time * 1000) / num_batches)
        return (0.0, 0.0, 0.0)
    def train_step(self, inputs: List[np.ndarray],
                   expecteds: List[np.ndarray]) -> Dict[str, Any]:
        """
        Train the model for a single batch.

        Args:
            inputs (list of np.ndarray): List of input arrays for the batch.
            expecteds (list of np.ndarray): List of expected output arrays for the batch.

        Returns:
            dict: Loss and metrics values, and train_time in ms.
        """
        if self.optimizer is None or self.loss is None:
            raise RuntimeError("Call `compile()` before `train_step()`")
        result = self._batch_step(inputs, expecteds, update_weights=True)
        result["train_time"] = result.pop("time_ms")
        return result
    def eval_step(self, inputs: List[np.ndarray],
                  expecteds: List[np.ndarray]) -> Dict[str, Any]:
        """
        Run one evaluation batch: forward only (no weight update).

        Args:
            inputs (list of np.ndarray): Inputs for this batch.
            expecteds (list of np.ndarray): Ground-truth outputs for this batch.

        Returns:
            dict with keys:
                "loss": list of float losses,
                "metrics": dict of metric_name -> value,
                "eval_time": float (ms)
        """
        if self.optimizer is None or self.loss is None:
            raise RuntimeError("Call `compile()` before `eval_step()`")
        result = self._batch_step(inputs, expecteds, update_weights=False)
        result["eval_time"] = result.pop("time_ms")
        return result
    def _batch_step(self, inputs: List[np.ndarray], expecteds: List[np.ndarray],
                    update_weights: bool) -> Dict[str, Any]:
        """
        Common logic for one batch: bind data, run, collect loss & metrics.
        Returns a dict with keys "loss", "metrics", "time_ms".
        """
        # Bind input buffers.
        for i, input_data in enumerate(inputs):
            self.session.train_set_input(i, input_data)
        # Bind expected outputs and pre-allocate output buffers.
        outputs: List[np.ndarray] = []
        for i, expected_data in enumerate(expecteds):
            self.session.train_set_expected(i, expected_data)
            output = np.zeros(expected_data.shape,
                              dtype=self.session.output_tensorinfo(i).dtype)
            self.session.train_set_output(i, output)
            outputs.append(output)
        t_start: float = time.perf_counter()
        self.session.train(update_weights=update_weights)
        t_end: float = time.perf_counter()
        losses = [self.session.train_get_loss(i) for i in range(len(expecteds))]
        metrics: Dict[str, float] = {}
        for m in self.metrics:
            m.update_state(outputs, expecteds)
            metrics[m.__class__.__name__] = m.result()
        return {
            "loss": losses,
            "metrics": metrics,
            "time_ms": (t_end - t_start) * 1000
        }