ONE - On-device Neural Engine
session.py
import time
import numpy as np
from typing import Any, List, Tuple, Dict, Union, Optional

from onert.native import libnnfw_api_pybind
from onert.native.libnnfw_api_pybind import traininfo
from onert.common.basesession import BaseSession
from .dataloader import DataLoader
from .losses.loss import LossFunction
from .losses.registry import LossRegistry
from .metrics.metric import Metric
from .metrics.registry import MetricsRegistry
from .optimizer.optimizer import Optimizer
from .optimizer.registry import OptimizerRegistry


# TODO: Support import checkpoint
class TrainSession(BaseSession):
    """
    Class for training and inference using nnfw_session.
    """
    def __init__(self, nnpackage_path: str, backends: str = "train") -> None:
        """
        Initialize the train session.

        Args:
            nnpackage_path (str): Path to the nnpackage file or directory.
            backends (str): Backends to use, default is "train".
        """
        load_start: float = time.perf_counter()
        super().__init__(
            libnnfw_api_pybind.experimental.nnfw_session(nnpackage_path, backends))
        load_end: float = time.perf_counter()

        self.total_time: Dict[str, Union[float, List[float]]] = {
            'MODEL_LOAD': (load_end - load_start) * 1000
        }
        self.train_info: traininfo = self.session.train_get_traininfo()
        self.optimizer: Optional[Optimizer] = None
        self.loss: Optional[LossFunction] = None
        self.metrics: List[Metric] = []

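    # A minimal construction sketch. The nnpackage path below is illustrative,
    # not part of this module; any trainable nnpackage exported for onert works.
    #
    #   session = TrainSession("./models/mnist.nnpkg", backends="train")
    #   print(session.total_time["MODEL_LOAD"])  # model-load time in ms
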
    def compile(self,
                optimizer: Union[str, Optimizer],
                loss: Union[str, LossFunction],
                metrics: List[Union[str, Metric]] = [],
                batch_size: int = 16) -> None:
        """
        Compile the session with optimizer, loss, and metrics.

        Args:
            optimizer (str or Optimizer): Optimizer instance or name.
            loss (str or LossFunction): Loss instance or name.
            metrics (list of str or Metric): Metrics to evaluate during training.
            batch_size (int): Number of samples per batch.
        """
        self.optimizer = (OptimizerRegistry.create_optimizer(optimizer) if isinstance(
            optimizer, str) else optimizer)
        self.loss = (LossRegistry.create_loss(loss) if isinstance(loss, str) else loss)
        self.metrics = [
            MetricsRegistry.create_metric(m) if isinstance(m, str) else m for m in metrics
        ]

        # Validate that all elements in self.metrics are instances of Metric
        for m in self.metrics:
            if not isinstance(m, Metric):
                raise TypeError(f"Invalid metric type: {type(m).__name__}. "
                                "All metrics must inherit from the Metric base class.")

        # Check that the number of metrics matches the number of outputs
        num_outputs: int = self.session.output_size()
        if 0 < len(self.metrics) != num_outputs:
            raise ValueError(
                f"Number of metrics ({len(self.metrics)}) does not match "
                f"the number of outputs ({num_outputs}). "
                "Please ensure one metric is provided for each model output.")

        # Set training info
        self.train_info.learning_rate = self.optimizer.learning_rate
        self.train_info.batch_size = batch_size
        self.train_info.loss_info.loss = LossRegistry.map_loss_function_to_enum(self.loss)
        self.train_info.loss_info.reduction_type = self.loss.reduction
        self.train_info.opt = OptimizerRegistry.map_optimizer_to_enum(self.optimizer)
        self.train_info.num_of_trainable_ops = self.optimizer.nums_trainable_ops
        self.session.train_set_traininfo(self.train_info)

        # Print training parameters
        self._print_training_parameters()

        # Prepare session for training
        compile_start: float = time.perf_counter()
        self.session.train_prepare()
        compile_end: float = time.perf_counter()
        self.total_time["COMPILE"] = (compile_end - compile_start) * 1000

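    # A minimal compile() sketch. The registry names "adam", "cce", and
    # "categorical_accuracy" are assumptions about what the optimizer, loss,
    # and metric registries accept; Optimizer/LossFunction/Metric instances
    # can be passed directly instead.
    #
    #   session.compile(optimizer="adam",
    #                   loss="cce",
    #                   metrics=["categorical_accuracy"],
    #                   batch_size=32)
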
    def _print_training_parameters(self) -> None:
        """
        Print the training parameters in a formatted way.
        """
        loss_name: str = self.loss.__class__.__name__ if self.loss else "Unknown Loss"
        reduction_name: str = (
            self.train_info.loss_info.reduction_type.name.lower().replace("_", " "))
        opt_name: str = (self.optimizer.__class__.__name__
                         if self.optimizer else "Unknown Optimizer")

        print("== training parameter ==")
        print(f"- learning_rate = {self.train_info.learning_rate:.4f}".rstrip('0').rstrip(
            '.'))
        print(f"- batch_size = {self.train_info.batch_size}")
        print(f"- loss_info = {{loss = {loss_name}, reduction = {reduction_name}}}")
        print(f"- optimizer = {opt_name}")
        print(f"- num_of_trainable_ops = {self.train_info.num_of_trainable_ops}")
        print("========================")

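    # For reference, the block printed above looks like the following; the
    # values and class names are illustrative, not defaults:
    #
    #   == training parameter ==
    #   - learning_rate = 0.001
    #   - batch_size = 32
    #   - loss_info = {loss = CategoricalCrossentropy, reduction = sum over batch size}
    #   - optimizer = Adam
    #   - num_of_trainable_ops = 10
    #   ========================
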
    def train(
            self,
            data_loader: DataLoader,
            epochs: int,
            validation_split: float = 0.0,
            checkpoint_path: Optional[str] = None
    ) -> Dict[str, Union[float, List[float]]]:
        """
        Train the model using the given data loader.

        Args:
            data_loader: Data loader providing input and expected data.
            epochs (int): Number of epochs to train.
            validation_split (float): Ratio of validation data. Default is 0.0.
            checkpoint_path (str, optional): Path to save training checkpoints.

        Returns:
            dict: Timing and performance metrics.
        """
        if self.optimizer is None or self.loss is None:
            raise RuntimeError("Call compile() before train().")

        train_data, val_data = data_loader.split(validation_split)
        epoch_times: List[float] = []

        for epoch in range(epochs):
            epoch_start = time.perf_counter()
            train_loss, io_ms, train_ms = self._run_phase(train_data, train=True)
            msg = [
                f"Epoch {epoch+1}/{epochs}", f"Train time: {train_ms:.3f}ms/step",
                f"IO time: {io_ms:.3f}ms/step", f"Train Loss: {train_loss:.4f}"
            ]

            if validation_split > 0.0:
                val_loss, _, _ = self._run_phase(val_data, train=False)
                msg.append(f"Validation Loss: {val_loss:.4f}")
                for m in self.metrics:
                    msg.append(f"{m.__class__.__name__}: {m.result():.4f}")
                    m.reset_state()

            epoch_times.append((time.perf_counter() - epoch_start) * 1000)
            print(" - ".join(msg))

        if checkpoint_path:
            self.session.train_export_checkpoint(checkpoint_path)

        self.total_time["EXECUTE"] = sum(epoch_times)
        self.total_time["EPOCH_TIMES"] = epoch_times
        return self.total_time

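    # An end-to-end training sketch. The DataLoader constructor arguments
    # shown here are assumptions for illustration; see dataloader.py for the
    # actual signature.
    #
    #   loader = DataLoader(x_train, y_train,
    #                       batch_size=session.train_info.batch_size)
    #   times = session.train(loader, epochs=5, validation_split=0.1,
    #                         checkpoint_path="train.ckpt")
    #   print(times["EXECUTE"], times["EPOCH_TIMES"])
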
    def _run_phase(self,
                   data: Tuple[List[np.ndarray], List[np.ndarray]],
                   train: bool = True) -> Tuple[float, float, float]:
        """
        Run a training or validation phase.

        Args:
            data: Data generator.
            train (bool): Whether to update weights.

        Returns:
            (avg_loss, avg_io_ms, avg_train_ms)
        """
        total_loss: float = 0.0
        num_batches: int = 0
        io_time: float = 0.0
        train_time: float = 0.0

        for inputs, expecteds in data:
            # Validate batch sizes
            self._check_batch_size(inputs, self.train_info.batch_size, "input")
            self._check_batch_size(expecteds, self.train_info.batch_size, "expected")

            # Set inputs
            io_start = time.perf_counter()
            for i, input_data in enumerate(inputs):
                self.session.train_set_input(i, input_data)

            # Set expected outputs and bind output buffers
            outputs: List[np.ndarray] = []
            for i, expected_data in enumerate(expecteds):
                expected = np.array(expected_data,
                                    dtype=self.session.output_tensorinfo(i).dtype)
                self.session.train_set_expected(i, expected)
                output = np.zeros(expected.shape,
                                  dtype=self.session.output_tensorinfo(i).dtype)
                self.session.train_set_output(i, output)
                outputs.append(output)
            io_end = time.perf_counter()

            # Run training or validation
            t_start = time.perf_counter()
            self.session.train(update_weights=train)
            t_end = time.perf_counter()

            # Accumulate loss
            batch_loss = sum(
                self.session.train_get_loss(i) for i in range(len(expecteds)))
            total_loss += batch_loss
            num_batches += 1

            # Update metrics (validation only)
            if not train:
                for m in self.metrics:
                    m.update_state(outputs, expecteds)

            # Accumulate times
            io_time += (io_end - io_start)
            train_time += (t_end - t_start)

        if num_batches:
            return (total_loss / num_batches, (io_time * 1000) / num_batches,
                    (train_time * 1000) / num_batches)
        return (0.0, 0.0, 0.0)

    def _check_batch_size(self,
                          data: List[np.ndarray],
                          batch_size: int,
                          data_type: str = "input") -> None:
        """
        Validate that the batch size of the data does not exceed the configured
        training batch size (a smaller final batch is allowed).

        Args:
            data (list of np.ndarray): The data to validate.
            batch_size (int): The expected batch size.
            data_type (str): 'input' or 'expected'.

        Raises:
            ValueError: If the batch size exceeds the expected value.
        """
        for idx, arr in enumerate(data):
            if arr.shape[0] > batch_size:
                raise ValueError(
                    f"{data_type} batch size mismatch at index {idx}: "
                    f"shape[0] = {arr.shape[0]} vs batch size = {batch_size}")

    def train_step(self, inputs: List[np.ndarray],
                   expecteds: List[np.ndarray]) -> Dict[str, Any]:
        """
        Train the model for a single batch.

        Args:
            inputs (list of np.ndarray): List of input arrays for the batch.
            expecteds (list of np.ndarray): List of expected output arrays for the batch.

        Returns:
            dict: Loss and metrics values, and train_time in ms.
        """
        if self.optimizer is None or self.loss is None:
            raise RuntimeError("Call `compile()` before `train_step()`")
        result = self._batch_step(inputs, expecteds, update_weights=True)
        result["train_time"] = result.pop("time_ms")
        return result

    def eval_step(self, inputs: List[np.ndarray],
                  expecteds: List[np.ndarray]) -> Dict[str, Any]:
        """
        Run one evaluation batch: forward only (no weight update).

        Args:
            inputs (list of np.ndarray): Inputs for this batch.
            expecteds (list of np.ndarray): Ground-truth outputs for this batch.

        Returns:
            dict: {
                "loss": list of float losses,
                "metrics": dict of metric_name -> value,
                "eval_time": float (ms)
            }
        """
        if self.optimizer is None or self.loss is None:
            raise RuntimeError("Call `compile()` before `eval_step()`")
        result = self._batch_step(inputs, expecteds, update_weights=False)
        result["eval_time"] = result.pop("time_ms")
        return result

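    # A step-wise sketch for custom loops built on train_step()/eval_step().
    # Each list holds one ndarray per model input/output; x_batch and y_batch
    # are placeholders for user data:
    #
    #   out = session.train_step([x_batch], [y_batch])
    #   print(out["loss"], out["metrics"], out["train_time"])
    #
    #   out = session.eval_step([x_batch], [y_batch])
    #   print(out["loss"], out["metrics"], out["eval_time"])
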
    def _batch_step(self, inputs: List[np.ndarray], expecteds: List[np.ndarray],
                    update_weights: bool) -> Dict[str, Any]:
        """
        Common logic for one batch: bind data, run, collect loss & metrics.
        Returns a dict with keys "loss", "metrics", "time_ms".
        """
        # Validate batch sizes
        self._check_batch_size(inputs, self.train_info.batch_size, "input")
        self._check_batch_size(expecteds, self.train_info.batch_size, "expected")

        # Set inputs
        for i, input_data in enumerate(inputs):
            self.session.train_set_input(i, input_data)

        # Set expected outputs and bind output buffers
        outputs: List[np.ndarray] = []
        for i, expected_data in enumerate(expecteds):
            self.session.train_set_expected(i, expected_data)
            output = np.zeros(expected_data.shape,
                              dtype=self.session.output_tensorinfo(i).dtype)
            self.session.train_set_output(i, output)
            outputs.append(output)

        # Run a single training step
        t_start: float = time.perf_counter()
        self.session.train(update_weights=update_weights)
        t_end: float = time.perf_counter()

        # Collect per-output losses
        losses = [self.session.train_get_loss(i) for i in range(len(expecteds))]

        # Update metrics
        metrics: Dict[str, float] = {}
        for m in self.metrics:
            m.update_state(outputs, expecteds)
            metrics[m.__class__.__name__] = m.result()
            m.reset_state()

        return {"loss": losses, "metrics": metrics, "time_ms": (t_end - t_start) * 1000}
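
# The timing dictionary built up across the public API uses the keys
# MODEL_LOAD, COMPILE, and EXECUTE (milliseconds) plus EPOCH_TIMES (a list of
# per-epoch milliseconds). A small reporting sketch:
#
#   for key, value in session.total_time.items():
#       print(f"{key}: {value}")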