ONE - On-device Neural Engine
Public Member Functions
__init__ (self, nnpackage_path, backends="train")
compile (self, optimizer, loss, metrics=[], batch_size=16)
train (self, data_loader, epochs, validation_split=0.0, checkpoint_path=None)
train_step (self, inputs, expecteds)
Data Fields
total_time
train_info
optimizer
loss
metrics
Protected Member Functions
_print_training_parameters (self)
_run_phase (self, data, train=True)
_check_batch_size (self, data, batch_size, data_type="input")
Class for training and inference using nnfw_session.
Definition at line 15 of file session.py.
session.TrainSession.__init__(self, nnpackage_path, backends="train")
Initialize the train session.
Args:
    nnpackage_path (str): Path to the nnpackage file or directory.
    backends (str): Backends to use. Default is "train".
Definition at line 19 of file session.py.
References session.TrainSession.__init__().
Referenced by session.TrainSession.__init__().
session.TrainSession._check_batch_size(self, data, batch_size, data_type="input") [protected]
Validate that the batch size of the data matches the configured training batch size.
Args:
    data (list of np.ndarray): The data to validate.
    batch_size (int): The expected batch size.
    data_type (str): Indicates whether the data is 'input' or 'expected'.
Raises:
    ValueError: If the batch size does not match the expected value.
Definition at line 236 of file session.py.
Referenced by session.TrainSession._run_phase(), and session.TrainSession.train_step().
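The check described above can be sketched as follows. This is a minimal illustration and not the code in session.py: the standalone function name `check_batch_size`, the error message wording, and the use of nested lists in place of np.ndarray are all assumptions made so the example runs with the standard library alone.

```python
def check_batch_size(data, batch_size, data_type="input"):
    """Raise ValueError if any array's leading dimension differs from batch_size.

    Hypothetical stand-in for the protected method; message text is illustrative.
    """
    for idx, array in enumerate(data):
        # len() returns the size of the first axis for nested lists
        # and for np.ndarray alike.
        actual = len(array)
        if actual != batch_size:
            raise ValueError(
                f"{data_type} data at index {idx} has batch size {actual}, "
                f"expected {batch_size}"
            )


# Two input tensors, each holding a batch of 2 samples.
inputs = [[[0.1, 0.2], [0.3, 0.4]], [[1.0], [2.0]]]
check_batch_size(inputs, batch_size=2)  # passes silently
```

Calling `check_batch_size(inputs, batch_size=3)` on the same data would raise ValueError, which mirrors the documented behavior when a batch does not match the size configured in compile().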
session.TrainSession._print_training_parameters(self) [protected]
Print the training parameters in a formatted way.
Definition at line 85 of file session.py.
References onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), and session.TrainSession.train_info.
Referenced by session.TrainSession.compile().
session.TrainSession._run_phase(self, data, train=True) [protected]
Run a training or validation phase.
Args:
    data: Data generator providing input and expected data.
    train (bool): Whether to perform training or validation.
Returns:
    float: Average loss for the phase.
Definition at line 168 of file session.py.
References session.TrainSession._check_batch_size(), session.TrainSession.metrics, validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, session.TrainSession.train(), and session.TrainSession.train_info.
Referenced by session.TrainSession.train().
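The "average loss for the phase" contract can be illustrated with a small sketch. It is not the implementation in session.py: the function `run_phase`, the `step_fn` callback, the `squared_error` helper, and the choice to return 0.0 for an empty phase are hypothetical, chosen only to show how a per-batch loss is reduced to one float.

```python
def run_phase(batches, step_fn):
    """Return the mean of step_fn(inputs, expecteds) over all batches."""
    total_loss = 0.0
    num_batches = 0
    for inputs, expecteds in batches:
        total_loss += step_fn(inputs, expecteds)
        num_batches += 1
    # Assumption: an empty phase (for example, no validation data) yields 0.0.
    return total_loss / num_batches if num_batches else 0.0


def squared_error(inputs, expecteds):
    """Toy per-batch loss: mean squared difference between two flat lists."""
    return sum((x - y) ** 2 for x, y in zip(inputs, expecteds)) / len(inputs)


batches = [([1.0, 2.0], [1.0, 4.0]), ([0.0, 0.0], [0.0, 0.0])]
average = run_phase(batches, squared_error)
```

In the real class, the per-batch work is done through the underlying session rather than a callback, and the `train` flag selects between training and validation.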
session.TrainSession.compile(self, optimizer, loss, metrics=[], batch_size=16)
Compile the session with optimizer, loss, and metrics.
Args:
    optimizer (Optimizer): Optimizer instance or str.
    loss (Loss): Loss instance or str.
    metrics (list): List of metrics to evaluate during training.
    batch_size (int): Number of samples per batch.
Raises:
    ValueError: If the number of metrics does not match the number of model outputs.
Definition at line 36 of file session.py.
References session.TrainSession._print_training_parameters(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, session.TrainSession.total_time, and session.TrainSession.train_info.
session.TrainSession.train(self, data_loader, epochs, validation_split=0.0, checkpoint_path=None)
Train the model using the given data loader.
Args:
    data_loader: A data loader providing input and expected data.
    epochs (int): Number of epochs to train.
    validation_split (float): Ratio of validation data. Default is 0.0 (no validation).
    checkpoint_path (str): Path to save or load the training checkpoint.
Definition at line 111 of file session.py.
References session.TrainSession._run_phase(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, and session.TrainSession.total_time.
Referenced by session.TrainSession._run_phase(), and session.TrainSession.train_step().
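To make the `validation_split` ratio concrete, the sketch below divides a list of batches into a training part and a validation part. How session.py actually performs the split is not stated on this page, so everything here is an assumption: the function name `split_batches`, splitting by whole batches, and taking the validation portion from the end of the data.

```python
def split_batches(batches, validation_split=0.0):
    """Split batches into (train, validation) lists by the given ratio.

    Hypothetical helper: holds out the trailing fraction for validation.
    """
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be in the range [0.0, 1.0)")
    num_validation = int(len(batches) * validation_split)
    split_at = len(batches) - num_validation
    return batches[:split_at], batches[split_at:]


batches = list(range(10))  # stand-ins for ten (inputs, expecteds) batches
train_part, validation_part = split_batches(batches, validation_split=0.2)
```

With the default of 0.0 the validation part is empty, which matches the documented "no validation" behavior.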
session.TrainSession.train_step(self, inputs, expecteds)
Train the model for a single batch.
Args:
    inputs (list of np.ndarray): List of input arrays for the batch.
    expecteds (list of np.ndarray): List of expected output arrays for the batch.
Returns:
    dict: A dictionary containing loss and metrics values.
Definition at line 253 of file session.py.
References session.TrainSession._check_batch_size(), onert_micro::OMTrainingContext.loss, nnfw_loss_info.loss, session.TrainSession.loss, session.TrainSession.metrics, onert_micro::OMTrainingContext.optimizer, session.TrainSession.optimizer, onert::backend::acl_common::AclBackendContext< T_TensorBuilder, T_ConstantInitializer, T_KernelGenerator, T_Optimizer >.optimizer, onert::backend::train::BackendContext.optimizer(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, package.common.basesession.BaseSession.session, session.TrainSession.train(), and session.TrainSession.train_info.
session.TrainSession.loss
Definition at line 33 of file session.py.
Referenced by session.TrainSession._print_training_parameters(), session.TrainSession.compile(), session.TrainSession.train(), and session.TrainSession.train_step().
session.TrainSession.metrics
Definition at line 34 of file session.py.
Referenced by session.TrainSession._run_phase(), session.TrainSession.compile(), session.TrainSession.train(), and session.TrainSession.train_step().
session.TrainSession.optimizer
Definition at line 32 of file session.py.
Referenced by session.TrainSession._print_training_parameters(), session.TrainSession.compile(), session.TrainSession.train(), and session.TrainSession.train_step().
session.TrainSession.total_time
Definition at line 30 of file session.py.
Referenced by session.TrainSession.compile(), and session.TrainSession.train().
session.TrainSession.train_info
Definition at line 31 of file session.py.
Referenced by session.TrainSession._print_training_parameters(), session.TrainSession._run_phase(), session.TrainSession.compile(), and session.TrainSession.train_step().