ONE - On-device Neural Engine
package.experimental.train.session.TrainSession Class Reference
Collaboration diagram for package.experimental.train.session.TrainSession:

Public Member Functions

None __init__ (self, str nnpackage_path, str backends="train")
 
None compile (self, Union[str, Optimizer] optimizer, Union[str, LossFunction] loss, List[Union[str, Metric]] metrics=[], int batch_size=16)
 
Dict[str, Union[float, List[float]]] train (self, DataLoader data_loader, int epochs, float validation_split=0.0, Optional[str] checkpoint_path=None)
 
Dict[str, Any] train_step (self, List[np.ndarray] inputs, List[np.ndarray] expecteds)
 
Dict[str, Any] eval_step (self, List[np.ndarray] inputs, List[np.ndarray] expecteds)
 

Data Fields

 optimizer
 
 loss
 
 metrics
 
 train_info
 

Protected Member Functions

None _print_training_parameters (self)
 
Tuple[float, float, float] _run_phase (self, Tuple[List[np.ndarray], List[np.ndarray]] data, bool train=True)
 
None _check_batch_size (self, List[np.ndarray] data, int batch_size, str data_type="input")
 
Dict[str, Any] _batch_step (self, List[np.ndarray] inputs, List[np.ndarray] expecteds, bool update_weights)
 

Detailed Description

Class for training and inference using nnfw_session.

Definition at line 18 of file session.py.
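
For orientation, here is a minimal end-to-end sketch of the intended workflow. It assumes a trainable nnpackage at "model.nnpkg", a DataLoader built elsewhere, and that "sgd" and "mse" are names known to OptimizerRegistry and LossRegistry; none of these specifics are confirmed by this page.

    from package.experimental.train import TrainSession  # import path is an assumption

    session = TrainSession("model.nnpkg")         # load the nnpackage with the "train" backend
    session.compile(optimizer="sgd", loss="mse", batch_size=16)
    times = session.train(data_loader, epochs=5)  # data_loader: a DataLoader instance
    print(times["EXECUTE"])                       # total execution time in ms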

Constructor & Destructor Documentation

◆ __init__()

None package.experimental.train.session.TrainSession.__init__ (   self,
str  nnpackage_path,
str   backends = "train" 
)
Initialize the train session.

Args:
    nnpackage_path (str): Path to the nnpackage file or directory.
    backends (str): Backends to use, default is "train".

Definition at line 22 of file session.py.

22 def __init__(self, nnpackage_path: str, backends: str = "train") -> None:
23     """
24     Initialize the train session.
25 
26     Args:
27         nnpackage_path (str): Path to the nnpackage file or directory.
28         backends (str): Backends to use, default is "train".
29     """
30     load_start: float = time.perf_counter()
31     super().__init__(
32         libnnfw_api_pybind.experimental.nnfw_session(nnpackage_path, backends))
33     load_end: float = time.perf_counter()
34 
35     self.total_time: Dict[str, Union[float, List[float]]] = {
36         'MODEL_LOAD': (load_end - load_start) * 1000
37     }
38     self.train_info: traininfo = self.session.train_get_traininfo()
39     self.optimizer: Optional[Optimizer] = None
40     self.loss: Optional[LossFunction] = None
41     self.metrics: List[Metric] = []
42 

References package.experimental.train.session.TrainSession.__init__(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, package.experimental.train.session.TrainSession.loss, package.experimental.train.session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, package.experimental.train.session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, and package.experimental.train.session.TrainSession.train_info.

Referenced by package.experimental.train.session.TrainSession.__init__().
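
As a small illustration, the constructor records the model-load time in total_time, so it can be read immediately after construction; the path below is hypothetical.

    session = TrainSession("path/to/nnpackage")
    print(session.total_time["MODEL_LOAD"])  # model load duration in milliseconds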

Member Function Documentation

◆ _batch_step()

Dict[str, Any] package.experimental.train.session.TrainSession._batch_step (   self,
List[np.ndarray]  inputs,
List[np.ndarray]  expecteds,
bool  update_weights 
)
protected
Common logic for one batch: bind data, run, collect loss & metrics.
Returns a dict with keys "loss", "metrics", "time_ms".

Definition at line 289 of file session.py.

289 def _batch_step(self, inputs: List[np.ndarray], expecteds: List[np.ndarray],
290                 update_weights: bool) -> Dict[str, Any]:
291     """
292     Common logic for one batch: bind data, run, collect loss & metrics.
293     Returns a dict with keys "loss", "metrics", "time_ms".
294     """
295     # Validate batch sizes
296     self._check_batch_size(inputs, self.train_info.batch_size, "input")
297     self._check_batch_size(expecteds, self.train_info.batch_size, "expected")
298 
299     # Set inputs
300     for i, input_data in enumerate(inputs):
301         self.session.train_set_input(i, input_data)
302 
303     # Set expected outputs
304     outputs: List[np.ndarray] = []
305     for i, expected_data in enumerate(expecteds):
306         self.session.train_set_expected(i, expected_data)
307         output = np.zeros(expected_data.shape,
308                           dtype=self.session.output_tensorinfo(i).dtype)
309         self.session.train_set_output(i, output)
310         outputs.append(output)
311 
312     # Run a single training step
313     t_start: float = time.perf_counter()
314     self.session.train(update_weights=update_weights)
315     t_end: float = time.perf_counter()
316 
317     # Collect per-output losses
318     losses = [self.session.train_get_loss(i) for i in range(len(expecteds))]
319 
320     # Update metrics
321     metrics: Dict[str, float] = {}
322     for m in self.metrics:
323         m.update_state(outputs, expecteds)
324         metrics[m.__class__.__name__] = m.result()
325         m.reset_state()
326 
327     return {"loss": losses, "metrics": metrics, "time_ms": (t_end - t_start) * 1000}

References package.experimental.train.session.TrainSession._check_batch_size(), package.experimental.train.session.TrainSession.metrics, validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, and package.experimental.train.session.TrainSession.train_info.

Referenced by package.experimental.train.session.TrainSession.eval_step(), and package.experimental.train.session.TrainSession.train_step().
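
The returned dictionary has the shape sketched below; the values and the metric name are illustrative only, not real output.

    result = self._batch_step(inputs, expecteds, update_weights=True)
    # result == {
    #     "loss": [0.42],                            # one float per model output
    #     "metrics": {"CategoricalAccuracy": 0.88},  # metric class name -> result
    #     "time_ms": 3.1,                            # duration of session.train() in ms
    # }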

◆ _check_batch_size()

None package.experimental.train.session.TrainSession._check_batch_size (   self,
List[np.ndarray]  data,
int  batch_size,
str   data_type = "input" 
)
protected
Validate that the batch size of the data does not exceed the configured training batch size.

Args:
    data (list of np.ndarray): The data to validate.
    batch_size (int): The expected batch size.
    data_type (str): 'input' or 'expected'.

Raises:
    ValueError: If the data's batch dimension exceeds the expected batch size.

Definition at line 228 of file session.py.

228 def _check_batch_size(self,
229                       data: List[np.ndarray],
230                       batch_size: int,
231                       data_type: str = "input") -> None:
232     """
233     Validate that the batch size of the data does not exceed the configured training batch size.
234 
235     Args:
236         data (list of np.ndarray): The data to validate.
237         batch_size (int): The expected batch size.
238         data_type (str): 'input' or 'expected'.
239 
240     Raises:
241         ValueError: If the data's batch dimension exceeds the expected batch size.
242     """
243     for idx, arr in enumerate(data):
244         if arr.shape[0] > batch_size:
245             raise ValueError(
246                 f"{data_type} batch size mismatch at index {idx}: "
247                 f"shape[0] = {arr.shape[0]} vs batch size = {batch_size}")
248 

Referenced by package.experimental.train.session.TrainSession._batch_step(), and package.experimental.train.session.TrainSession._run_phase().
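
Note that the check only rejects arrays whose batch dimension exceeds the configured size, so a smaller final batch passes. A sketch, assuming a configured batch size of 16:

    import numpy as np

    self._check_batch_size([np.zeros((16, 10))], 16)  # OK: matches the batch size
    self._check_batch_size([np.zeros((8, 10))], 16)   # OK: smaller (e.g. last) batch allowed
    self._check_batch_size([np.zeros((32, 10))], 16)  # ValueError: input batch size mismatch at index 0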

◆ _print_training_parameters()

None package.experimental.train.session.TrainSession._print_training_parameters (   self)
protected
Print the training parameters in a formatted way.

Definition at line 95 of file session.py.

95 def _print_training_parameters(self) -> None:
96     """
97     Print the training parameters in a formatted way.
98     """
99     loss_name: str = self.loss.__class__.__name__ if self.loss else "Unknown Loss"
100     reduction_name: str = (
101         self.train_info.loss_info.reduction_type.name.lower().replace("_", " "))
102     opt_name: str = self.optimizer.__class__.__name__ if self.optimizer else "Unknown Optimizer"
103 
104     print("== training parameter ==")
105     print(f"- learning_rate = {self.train_info.learning_rate:.4f}".rstrip('0').rstrip(
106         '.'))
107     print(f"- batch_size = {self.train_info.batch_size}")
108     print(f"- loss_info = {{loss = {loss_name}, reduction = {reduction_name}}}")
109     print(f"- optimizer = {opt_name}")
110     print(f"- num_of_trainable_ops = {self.train_info.num_of_trainable_ops}")
111     print("========================")
112 

References onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, package.experimental.train.session.TrainSession.loss, onert_micro::OMTrainingContext.optimizer, package.experimental.train.session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), and package.experimental.train.session.TrainSession.train_info.
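
Given the print statements above, the output has the following shape; the learning rate, class names, and counts below are illustrative, not real output.

    == training parameter ==
    - learning_rate = 0.001
    - batch_size = 16
    - loss_info = {loss = MeanSquaredError, reduction = sum over batch size}
    - optimizer = SGD
    - num_of_trainable_ops = 10
    ========================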

◆ _run_phase()

Tuple[float, float, float] package.experimental.train.session.TrainSession._run_phase (   self,
Tuple[List[np.ndarray], List[np.ndarray]]  data,
bool   train = True 
)
protected
Run a training or validation phase.

Args:
    data: Data generator.
    train (bool): Whether to update weights.

Returns:
    (avg_loss, avg_io_ms, avg_train_ms)

Definition at line 163 of file session.py.

163 def _run_phase(self,
164                data: Tuple[List[np.ndarray], List[np.ndarray]],
165                train: bool = True) -> Tuple[float, float, float]:
166     """
167     Run a training or validation phase.
168 
169     Args:
170         data: Data generator.
171         train (bool): Whether to update weights.
172 
173     Returns:
174         (avg_loss, avg_io_ms, avg_train_ms)
175     """
176     total_loss: float = 0.0
177     num_batches: int = 0
178     io_time: float = 0.0
179     train_time: float = 0.0
180 
181     for inputs, expecteds in data:
182         # Validate batch sizes
183         self._check_batch_size(inputs, self.train_info.batch_size, "input")
184         self._check_batch_size(expecteds, self.train_info.batch_size, "expected")
185 
186         # Set inputs
187         io_start = time.perf_counter()
188         for i, input_data in enumerate(inputs):
189             self.session.train_set_input(i, input_data)
190 
191         # Set expected outputs
192         outputs: List[np.ndarray] = []
193         for i, expected_data in enumerate(expecteds):
194             expected = np.array(expected_data,
195                                 dtype=self.session.output_tensorinfo(i).dtype)
196             self.session.train_set_expected(i, expected)
197             output = np.zeros(expected.shape,
198                               dtype=self.session.output_tensorinfo(i).dtype)
199             self.session.train_set_output(i, output)
200             outputs.append(output)
201         io_end = time.perf_counter()
202 
203         # Run training or validation
204         t_start = time.perf_counter()
205         self.session.train(update_weights=train)
206         t_end = time.perf_counter()
207 
208         # Accumulate loss
209         batch_loss = sum(
210             self.session.train_get_loss(i) for i in range(len(expecteds)))
211         total_loss += batch_loss
212         num_batches += 1
213 
214         # Update metrics (validation phase only)
215         if not train:
216             for m in self.metrics:
217                 m.update_state(outputs, expecteds)
218 
219         # Accumulate times
220         io_time += (io_end - io_start)
221         train_time += (t_end - t_start)
222 
223     if num_batches:
224         return (total_loss / num_batches, (io_time * 1000) / num_batches,
225                 (train_time * 1000) / num_batches)
226     return (0.0, 0.0, 0.0)
227 

References package.experimental.train.session.TrainSession._check_batch_size(), package.experimental.train.session.TrainSession.metrics, validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, and package.experimental.train.session.TrainSession.train_info.

Referenced by package.experimental.train.session.TrainSession.train().
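
For reference, train() consumes this helper once per epoch by unpacking the returned per-step averages, mirroring the call in train()'s listing below:

    avg_loss, avg_io_ms, avg_train_ms = self._run_phase(train_data, train=True)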

◆ compile()

None package.experimental.train.session.TrainSession.compile (   self,
Union[str, Optimizer]  optimizer,
Union[str, LossFunction]  loss,
List[Union[str, Metric]]   metrics = [],
int   batch_size = 16 
)
Compile the session with optimizer, loss, and metrics.

Args:
    optimizer (str or Optimizer): Optimizer instance or name.
    loss (str or LossFunction): Loss instance or name.
    metrics (list of str or Metric): Metrics to evaluate during training.
    batch_size (int): Number of samples per batch.

Definition at line 43 of file session.py.

43 def compile(self,
44             optimizer: Union[str, Optimizer],
45             loss: Union[str, LossFunction],
46             metrics: List[Union[str, Metric]] = [],
47             batch_size: int = 16) -> None:
48     """
49     Compile the session with optimizer, loss, and metrics.
50 
51     Args:
52         optimizer (str or Optimizer): Optimizer instance or name.
53         loss (str or LossFunction): Loss instance or name.
54         metrics (list of str or Metric): Metrics to evaluate during training.
55         batch_size (int): Number of samples per batch.
56     """
57     self.optimizer = (OptimizerRegistry.create_optimizer(optimizer) if isinstance(
58         optimizer, str) else optimizer)
59     self.loss = (LossRegistry.create_loss(loss) if isinstance(loss, str) else loss)
60     self.metrics = [
61         MetricsRegistry.create_metric(m) if isinstance(m, str) else m for m in metrics
62     ]
63 
64     # Validate that all elements in self.metrics are instances of Metric
65     for m in self.metrics:
66         if not isinstance(m, Metric):
67             raise TypeError(f"Invalid metric type: {type(m).__name__}. "
68                             "All metrics must inherit from the Metric base class.")
69 
70     # Check that the number of metrics matches the number of outputs
71     num_outputs: int = self.session.output_size()
72     if 0 < len(self.metrics) != num_outputs:
73         raise ValueError(
74             f"Number of metrics ({len(self.metrics)}) does not match outputs ({num_outputs}). "
75             "Please ensure one metric is provided for each model output.")
76 
77     # Set training info
78     self.train_info.learning_rate = self.optimizer.learning_rate
79     self.train_info.batch_size = batch_size
80     self.train_info.loss_info.loss = LossRegistry.map_loss_function_to_enum(self.loss)
81     self.train_info.loss_info.reduction_type = self.loss.reduction
82     self.train_info.opt = OptimizerRegistry.map_optimizer_to_enum(self.optimizer)
83     self.train_info.num_of_trainable_ops = self.optimizer.nums_trainable_ops
84     self.session.train_set_traininfo(self.train_info)
85 
86     # Print training parameters
87     self._print_training_parameters()
88 
89     # Prepare session for training
90     compile_start: float = time.perf_counter()
91     self.session.train_prepare()
92     compile_end: float = time.perf_counter()
93     self.total_time["COMPILE"] = (compile_end - compile_start) * 1000
94 
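
A hedged sketch of a typical compile() call; the string keys "adam", "mse", and "categorical_accuracy" are assumed registry names and may differ from what OptimizerRegistry, LossRegistry, and MetricsRegistry actually accept.

    session.compile(optimizer="adam",
                    loss="mse",
                    metrics=["categorical_accuracy"],  # one metric per model output
                    batch_size=32)

Note that metrics defaults to a mutable list ([]); the method does not mutate it, but passing metrics explicitly sidesteps the usual shared-default pitfall.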

◆ eval_step()

Dict[str, Any] package.experimental.train.session.TrainSession.eval_step (   self,
List[np.ndarray]  inputs,
List[np.ndarray]  expecteds 
)
Run one evaluation batch: forward only (no weight update).

Args:
    inputs (list of np.ndarray): Inputs for this batch.
    expecteds (list of np.ndarray): Ground-truth outputs for this batch.

Returns:
    dict: {
        "loss": list of float losses,
        "metrics": dict of metric_name -> value,
        "eval_time": float (ms)
    }

Definition at line 267 of file session.py.

267 def eval_step(self, inputs: List[np.ndarray],
268               expecteds: List[np.ndarray]) -> Dict[str, Any]:
269     """
270     Run one evaluation batch: forward only (no weight update).
271 
272     Args:
273         inputs (list of np.ndarray): Inputs for this batch.
274         expecteds (list of np.ndarray): Ground-truth outputs for this batch.
275 
276     Returns:
277         dict: {
278             "loss": list of float losses,
279             "metrics": dict of metric_name -> value,
280             "eval_time": float (ms)
281         }
282     """
283     if self.optimizer is None or self.loss is None:
284         raise RuntimeError("Call `compile()` before `eval_step()`")
285     result = self._batch_step(inputs, expecteds, update_weights=False)
286     result["eval_time"] = result.pop("time_ms")
287     return result
288 

References package.experimental.train.session.TrainSession._batch_step(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, package.experimental.train.session.TrainSession.loss, onert_micro::OMTrainingContext.optimizer, package.experimental.train.session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, and onert::backend::train::BackendContext.optimizer().
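
For illustration, one evaluation batch after compile(); the shapes and dtypes below are hypothetical.

    import numpy as np

    x = [np.random.rand(16, 28, 28, 1).astype(np.float32)]  # inputs for one batch
    y = [np.zeros((16, 10), dtype=np.float32)]              # expected outputs
    result = session.eval_step(x, y)
    print(result["loss"], result["metrics"], result["eval_time"])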

◆ train()

Dict[str, Union[float, List[float]]] package.experimental.train.session.TrainSession.train (   self,
DataLoader  data_loader,
int  epochs,
float   validation_split = 0.0,
Optional[str]   checkpoint_path = None 
)
Train the model using the given data loader.

Args:
    data_loader: Data loader providing input and expected data.
    epochs (int): Number of epochs to train.
    validation_split (float): Ratio of validation data. Default is 0.0.
    checkpoint_path (str, optional): Path to save training checkpoints.

Returns:
    dict: Timing and performance metrics.

Definition at line 113 of file session.py.

113 def train(
114     self,
115     data_loader: DataLoader,
116     epochs: int,
117     validation_split: float = 0.0,
118     checkpoint_path: Optional[str] = None
119 ) -> Dict[str, Union[float, List[float]]]:
120     """
121     Train the model using the given data loader.
122 
123     Args:
124         data_loader: Data loader providing input and expected data.
125         epochs (int): Number of epochs to train.
126         validation_split (float): Ratio of validation data. Default is 0.0.
127         checkpoint_path (str, optional): Path to save training checkpoints.
128 
129     Returns:
130         dict: Timing and performance metrics.
131     """
132     if self.optimizer is None or self.loss is None:
133         raise RuntimeError("Call compile() before train().")
134 
135     train_data, val_data = data_loader.split(validation_split)
136     epoch_times: List[float] = []
137 
138     for epoch in range(epochs):
139         epoch_start = time.perf_counter()
140         train_loss, io_ms, train_ms = self._run_phase(train_data, train=True)
141         msg = [
142             f"Epoch {epoch+1}/{epochs}", f"Train time: {train_ms:.3f}ms/step",
143             f"IO time: {io_ms:.3f}ms/step", f"Train Loss: {train_loss:.4f}"
144         ]
145 
146         if validation_split > 0.0:
147             val_loss, _, _ = self._run_phase(val_data, train=False)
148             msg.append(f"Validation Loss: {val_loss:.4f}")
149             for m in self.metrics:
150                 msg.append(f"{m.__class__.__name__}: {m.result():.4f}")
151                 m.reset_state()
152 
153         epoch_times.append((time.perf_counter() - epoch_start) * 1000)
154         print(" - ".join(msg))
155 
156     if checkpoint_path:
157         self.session.train_export_checkpoint(checkpoint_path)
158 
159     self.total_time["EXECUTE"] = sum(epoch_times)
160     self.total_time["EPOCH_TIMES"] = epoch_times
161     return self.total_time
162 

References package.experimental.train.session.TrainSession._run_phase(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, package.experimental.train.session.TrainSession.loss, package.experimental.train.session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, package.experimental.train.session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, and package.common.basesession.BaseSession.session.
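
A sketch of a full training run, assuming data_loader is a DataLoader instance and that "ckpt" is a writable checkpoint path (both hypothetical); the returned dict carries the MODEL_LOAD, COMPILE, EXECUTE, and per-epoch EPOCH_TIMES entries in milliseconds.

    times = session.train(data_loader,
                          epochs=5,
                          validation_split=0.1,
                          checkpoint_path="ckpt")
    print(times["EXECUTE"], times["EPOCH_TIMES"])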

◆ train_step()

Dict[str, Any] package.experimental.train.session.TrainSession.train_step (   self,
List[np.ndarray]  inputs,
List[np.ndarray]  expecteds 
)
Train the model for a single batch.

Args:
    inputs (list of np.ndarray): List of input arrays for the batch.
    expecteds (list of np.ndarray): List of expected output arrays for the batch.

Returns:
    dict: Loss and metrics values, and train_time in ms.

Definition at line 249 of file session.py.

249 def train_step(self, inputs: List[np.ndarray],
250                expecteds: List[np.ndarray]) -> Dict[str, Any]:
251     """
252     Train the model for a single batch.
253 
254     Args:
255         inputs (list of np.ndarray): List of input arrays for the batch.
256         expecteds (list of np.ndarray): List of expected output arrays for the batch.
257 
258     Returns:
259         dict: Loss and metrics values, and train_time in ms.
260     """
261     if self.optimizer is None or self.loss is None:
262         raise RuntimeError("Call `compile()` before `train_step()`")
263     result = self._batch_step(inputs, expecteds, update_weights=True)
264     result["train_time"] = result.pop("time_ms")
265     return result
266 

References package.experimental.train.session.TrainSession._batch_step(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, package.experimental.train.session.TrainSession.loss, onert_micro::OMTrainingContext.optimizer, package.experimental.train.session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, and onert::backend::train::BackendContext.optimizer().
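
A sketch of a custom training loop built on train_step(); batches is a hypothetical iterable yielding (inputs, expecteds) pairs of np.ndarray lists.

    for inputs, expecteds in batches:
        result = session.train_step(inputs, expecteds)
        print(f"loss={result['loss']}, took {result['train_time']:.2f} ms")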

Field Documentation

◆ loss

package.experimental.train.session.TrainSession.loss

Definition at line 40 of file session.py.

◆ metrics

package.experimental.train.session.TrainSession.metrics

Definition at line 41 of file session.py.

◆ optimizer

package.experimental.train.session.TrainSession.optimizer

Definition at line 39 of file session.py.

◆ train_info

package.experimental.train.session.TrainSession.train_info

Definition at line 38 of file session.py.


The documentation for this class was generated from the following file:

session.py