ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
session.TrainSession Class Reference
Collaboration diagram for session.TrainSession:

Public Member Functions

 __init__ (self, nnpackage_path, backends="train")
 
 compile (self, optimizer, loss, metrics=[], batch_size=16)
 
 train (self, data_loader, epochs, validation_split=0.0, checkpoint_path=None)
 
 train_step (self, inputs, expecteds)
 

Data Fields

 total_time
 
 train_info
 
 optimizer
 
 loss
 
 metrics
 

Protected Member Functions

 _print_training_parameters (self)
 
 _run_phase (self, data, train=True)
 
 _check_batch_size (self, data, batch_size, data_type="input")
 

Detailed Description

Class for training and inference using nnfw_session.

Definition at line 15 of file session.py.

Constructor & Destructor Documentation

◆ __init__()

session.TrainSession.__init__ (   self,
  nnpackage_path,
  backends = "train" 
)
Initialize the train session.
Args:
    nnpackage_path (str): Path to the nnpackage file or directory.
    backends (str): Backends to use, default is "train".

Definition at line 19 of file session.py.

19 def __init__(self, nnpackage_path, backends="train"):
20 """
21 Initialize the train session.
22 Args:
23 nnpackage_path (str): Path to the nnpackage file or directory.
24 backends (str): Backends to use, default is "train".
25 """
26 load_start = time.perf_counter()
27 super().__init__(
28 libnnfw_api_pybind.experimental.nnfw_session(nnpackage_path, backends))
29 load_end = time.perf_counter()
30 self.total_time = {'MODEL_LOAD': (load_end - load_start) * 1000}
31 self.train_info = self.session.train_get_traininfo()
32 self.optimizer = None
33 self.loss = None
34 self.metrics = []
35

References session.TrainSession.__init__().

Referenced by session.TrainSession.__init__().

Member Function Documentation

◆ _check_batch_size()

session.TrainSession._check_batch_size (   self,
  data,
  batch_size,
  data_type = "input" 
)
protected
Validate that the batch size of the data does not exceed the configured training batch size.
Args:
    data (list of np.ndarray): The data to validate.
    batch_size (int): The expected batch size.
    data_type (str): A string to indicate whether the data is 'input' or 'expected'.
Raises:
    ValueError: If the data's batch dimension exceeds the expected batch size.

Definition at line 236 of file session.py.

236 def _check_batch_size(self, data, batch_size, data_type="input"):
237 """
238 Validate that the batch size of the data does not exceed the configured training batch size.
239 Args:
240 data (list of np.ndarray): The data to validate.
241 batch_size (int): The expected batch size.
242 data_type (str): A string to indicate whether the data is 'input' or 'expected'.
243 Raises:
244 ValueError: If the data's batch dimension exceeds the expected batch size.
245 """
246 for i, array in enumerate(data):
247 if array.shape[0] > batch_size:
248 raise ValueError(
249 f"Batch size mismatch for {data_type} data at index {i}: "
250 f"batch size ({array.shape[0]}) does not match the configured "
251 f"training batch size ({batch_size}).")
252

Referenced by session.TrainSession._run_phase(), and session.TrainSession.train_step().

◆ _print_training_parameters()

session.TrainSession._print_training_parameters (   self)
protected
Print the training parameters in a formatted way.

Definition at line 85 of file session.py.

85 def _print_training_parameters(self):
86 """
87 Print the training parameters in a formatted way.
88 """
89 # Get loss function name
90 loss_name = self.loss.__class__.__name__ if self.loss else "Unknown Loss"
91
92 # Get reduction type name from enum value
93 reduction_name = self.train_info.loss_info.reduction_type.name.lower().replace(
94 "_", " ")
95
96 # Get optimizer name
97 optimizer_name = self.optimizer.__class__.__name__ if self.optimizer else "Unknown Optimizer"
98
99 print("== training parameter ==")
100 print(
101 f"- learning_rate = {f'{self.train_info.learning_rate:.4f}'.rstrip('0').rstrip('.')}"
102 )
103 print(f"- batch_size = {self.train_info.batch_size}")
104 print(
105 f"- loss_info = {{loss = {loss_name}, reduction = {reduction_name}}}"
106 )
107 print(f"- optimizer = {optimizer_name}")
108 print(f"- num_of_trainable_ops = {self.train_info.num_of_trainable_ops}")
109 print("========================")
110

References onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), and session.TrainSession.train_info.

Referenced by session.TrainSession.compile().

◆ _run_phase()

session.TrainSession._run_phase (   self,
  data,
  train = True 
)
protected
Run a training or validation phase.
Args:
    data: Data generator providing input and expected data.
    train (bool): Whether to perform training or validation.
Returns:
    float: Average loss for the phase.

Definition at line 168 of file session.py.

168 def _run_phase(self, data, train=True):
169 """
170 Run a training or validation phase.
171 Args:
172 data: Data generator providing input and expected data.
173 train (bool): Whether to perform training or validation.
174 Returns:
175 float: Average loss for the phase.
176 """
177 total_loss = 0.0
178 num_batches = 0
179
180 io_time = 0
181 train_time = 0
182
183 for inputs, expecteds in data:
184 # Validate batch sizes
185 self._check_batch_size(inputs, self.train_info.batch_size, data_type="input")
186 self._check_batch_size(expecteds,
187 self.train_info.batch_size,
188 data_type="expected")
189
190 set_io_start = time.perf_counter()
191 # Set inputs
192 for i, input_data in enumerate(inputs):
193 self.session.train_set_input(i, input_data)
194
195 # Set expected outputs
196 outputs = []
197 for i, expected_data in enumerate(expecteds):
198 expected = np.array(expected_data,
199 dtype=self.session.output_tensorinfo(i).dtype)
200 self.session.train_set_expected(i, expected)
201
202 output = np.zeros(expected.shape,
203 dtype=self.session.output_tensorinfo(i).dtype)
204 self.session.train_set_output(i, output)
205 assert i == len(outputs)
206 outputs.append(output)
207
208 set_io_end = time.perf_counter()
209
210 # Run training or validation
211 train_start = time.perf_counter()
212 self.session.train(update_weights=train)
213 train_end = time.perf_counter()
214
215 # Accumulate loss
216 batch_loss = sum(
217 self.session.train_get_loss(i) for i in range(len(expecteds)))
218 total_loss += batch_loss
219 num_batches += 1
220
221 # Update metrics
222 if not train:
223 for metric in self.metrics:
224 metric.update_state(outputs, expecteds)
225
226 # Calculate times
227 io_time += (set_io_end - set_io_start)
228 train_time += (train_end - train_start)
229
230 if num_batches > 0:
231 return (total_loss / num_batches, (io_time * 1000) / num_batches,
232 (train_time * 1000) / num_batches)
233 else:
234 return (0.0, 0.0, 0.0)
235

References session.TrainSession._check_batch_size(), session.TrainSession.metrics, validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, session.TrainSession.train(), and session.TrainSession.train_info.

Referenced by session.TrainSession.train().

◆ compile()

session.TrainSession.compile (   self,
  optimizer,
  loss,
  metrics = [],
  batch_size = 16 
)
Compile the session with optimizer, loss, and metrics.
Args:
    optimizer (Optimizer): Optimizer instance or str.
    loss (Loss): Loss instance or str.
    metrics (list): List of metrics to evaluate during training.
    batch_size (int): Number of samples per batch.
Raises:
    ValueError: If the number of metrics does not match the number of model outputs.

Definition at line 36 of file session.py.

36 def compile(self, optimizer, loss, metrics=[], batch_size=16):
37 """
38 Compile the session with optimizer, loss, and metrics.
39 Args:
40 optimizer (Optimizer): Optimizer instance or str.
41 loss (Loss): Loss instance or str.
42 metrics (list): List of metrics to evaluate during training.
43 batch_size (int): Number of samples per batch.
44 Raises:
45 ValueError: If the number of metrics does not match the number of model outputs.
46 """
47 self.optimizer = OptimizerRegistry.create_optimizer(optimizer) if isinstance(
48 optimizer, str) else optimizer
49 self.loss = LossRegistry.create_loss(loss) if isinstance(loss, str) else loss
50 self.metrics = [
51 MetricsRegistry.create_metric(m) if isinstance(m, str) else m for m in metrics
52 ]
53
54 # Validate that all elements in self.metrics are instances of Metric
55 for metric in self.metrics:
56 if not isinstance(metric, Metric):
57 raise TypeError(f"Invalid metric type: {type(metric).__name__}. "
58 "All metrics must inherit from the Metric base class.")
59
60 # Check if the number of metrics matches the number of outputs
61 num_model_outputs = self.session.output_size()
62 if 0 < len(self.metrics) != num_model_outputs:
63 raise ValueError(
64 f"Number of metrics ({len(self.metrics)}) does not match the number of model outputs ({num_model_outputs}). "
65 "Please ensure one metric is provided for each model output.")
66
67 # Set training information
68 self.train_info.learning_rate = optimizer.learning_rate
69 self.train_info.batch_size = batch_size
70 self.train_info.loss_info.loss = LossRegistry.map_loss_function_to_enum(loss)
71 self.train_info.loss_info.reduction_type = loss.reduction
72 self.train_info.opt = OptimizerRegistry.map_optimizer_to_enum(optimizer)
73 self.train_info.num_of_trainable_ops = optimizer.nums_trainable_ops
74 self.session.train_set_traininfo(self.train_info)
75
76 # Print training parameters
77 self._print_training_parameters()
78
79 # Prepare session for training
80 compile_start = time.perf_counter()
81 self.session.train_prepare()
82 compile_end = time.perf_counter()
83 self.total_time["COMPILE"] = (compile_end - compile_start) * 1000
84

References session.TrainSession._print_training_parameters(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, session.TrainSession.total_time, and session.TrainSession.train_info.

◆ train()

session.TrainSession.train (   self,
  data_loader,
  epochs,
  validation_split = 0.0,
  checkpoint_path = None 
)
Train the model using the given data loader.
Args:
    data_loader: A data loader providing input and expected data.
    epochs (int): Number of epochs to train.
    validation_split (float): Ratio of validation data. Default is 0.0 (no validation).
    checkpoint_path (str): Path to save or load the training checkpoint.

Definition at line 111 of file session.py.

111 def train(self, data_loader, epochs, validation_split=0.0, checkpoint_path=None):
112 """
113 Train the model using the given data loader.
114 Args:
115 data_loader: A data loader providing input and expected data.
117 epochs (int): Number of epochs to train.
118 validation_split (float): Ratio of validation data. Default is 0.0 (no validation).
119 checkpoint_path (str): Path to save or load the training checkpoint.
120 """
121 if self.optimizer is None or self.loss is None:
122 raise RuntimeError(
123 "The training session is not properly configured. "
124 "Please call `compile(optimizer, loss)` before calling `train()`.")
125
126 # Split data into training and validation
127 train_data, val_data = data_loader.split(validation_split)
128
129 # Timings for summary
130 epoch_times = []
131
132 # Training loop
133 for epoch in range(epochs):
134 message = [f"Epoch {epoch + 1}/{epochs}"]
135
136 epoch_start_time = time.perf_counter()
137 # Training phase
138 train_loss, avg_io_time, avg_train_time = self._run_phase(train_data,
139 train=True)
140 message.append(f"Train time: {avg_train_time:.3f}ms/step")
141 message.append(f"IO time: {avg_io_time:.3f}ms/step")
142 message.append(f"Train Loss: {train_loss:.4f}")
143
144 # Validation phase
145 if validation_split > 0.0:
146 val_loss, _, _ = self._run_phase(val_data, train=False)
147 message.append(f"Validation Loss: {val_loss:.4f}")
148
149 # Print metrics
150 for metric in self.metrics:
151 message.append(f"{metric.__class__.__name__}: {metric.result():.4f}")
152 metric.reset_state()
153
154 epoch_time = (time.perf_counter() - epoch_start_time) * 1000
155 epoch_times.append(epoch_time)
156
157 print(" - ".join(message))
158
159 # Save checkpoint
160 if checkpoint_path is not None:
161 self.session.train_export_checkpoint(checkpoint_path)
162
163 self.total_time["EXECUTE"] = sum(epoch_times)
164 self.total_time["EPOCH_TIMES"] = epoch_times
165
166 return self.total_time
167

References session.TrainSession._run_phase(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, and session.TrainSession.total_time.

Referenced by session.TrainSession._run_phase(), and session.TrainSession.train_step().

◆ train_step()

session.TrainSession.train_step (   self,
  inputs,
  expecteds 
)
Train the model for a single batch.
Args:
    inputs (list of np.ndarray): List of input arrays for the batch.
    expecteds (list of np.ndarray): List of expected output arrays for the batch.
Returns:
    dict: A dictionary containing loss and metrics values.

Definition at line 253 of file session.py.

253 def train_step(self, inputs, expecteds):
254 """
255 Train the model for a single batch.
256 Args:
257 inputs (list of np.ndarray): List of input arrays for the batch.
258 expecteds (list of np.ndarray): List of expected output arrays for the batch.
259 Returns:
260 dict: A dictionary containing loss and metrics values.
261 """
262 if self.optimizer is None or self.loss is None:
263 raise RuntimeError(
264 "The training session is not properly configured. "
265 "Please call `compile(optimizer, loss)` before calling `train_step()`.")
266
267 # Validate batch sizes
268 self._check_batch_size(inputs, self.train_info.batch_size, data_type="input")
269 self._check_batch_size(expecteds,
270 self.train_info.batch_size,
271 data_type="expected")
272
273 # Set inputs
274 for i, input_data in enumerate(inputs):
275 self.session.train_set_input(i, input_data)
276
277 # Set expected outputs
278 outputs = []
279 for i, expected_data in enumerate(expecteds):
280 self.session.train_set_expected(i, expected_data)
281 output = np.zeros(expected_data.shape,
282 dtype=self.session.output_tensorinfo(i).dtype)
283 self.session.train_set_output(i, output)
284 outputs.append(output)
285
286 # Run a single training step
287 train_start = time.perf_counter()
288 self.session.train(update_weights=True)
289 train_end = time.perf_counter()
290
291 # Calculate loss
292 losses = [self.session.train_get_loss(i) for i in range(len(expecteds))]
293
294 # Update metrics
295 metric_results = {}
296 for metric in self.metrics:
297 metric.update_state(outputs, expecteds)
298 metric_results[metric.__class__.__name__] = metric.result()
299
300 return {
301 "loss": losses,
302 "metrics": metric_results,
303 "train_time": (train_end - train_start) * 1000
304 }

References session.TrainSession._check_batch_size(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, session.TrainSession.train(), and session.TrainSession.train_info.

Field Documentation

◆ loss

◆ metrics

◆ optimizer

◆ total_time

session.TrainSession.total_time

Definition at line 30 of file session.py.

Referenced by session.TrainSession.compile(), and session.TrainSession.train().

◆ train_info


The documentation for this class was generated from the following file: