import argparse

from onert.experimental.train import session, DataLoader, optimizer, losses, metrics
def initParse(argv=None):
    """Parse the command line arguments of the training script.

    Args:
        argv (list[str] | None): Argument list to parse. When None,
            argparse falls back to sys.argv[1:] (the original behavior).

    Returns:
        argparse.Namespace: Parsed arguments (nnpkg, input, label,
        data_length, backends, batch_size, epoch, learning_rate, loss,
        optimizer, loss_reduction_type, validation_split).
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): the long-option names for -m/-i/-l are reconstructed from
    # the attribute accesses in train() (args.nnpkg, args.input, args.label)
    # — confirm against the original file.
    parser.add_argument('-m',
                        '--nnpkg',
                        required=True,
                        help='Path to the nnpackage file or directory')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Path to the file containing input data (e.g., .npy or raw)')
    parser.add_argument('-l',
                        '--label',
                        required=True,
                        help='Path to the file containing label data (e.g., .npy or raw).')
    parser.add_argument('--data_length', required=True, type=int, help='data length')
    parser.add_argument('--backends', default='train', help='Backends to use')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--epoch', default=5, type=int, help='epoch number')
    parser.add_argument('--learning_rate', default=0.01, type=float,
                        help='learning rate')
    parser.add_argument('--loss', default='mse', choices=['mse', 'cce'])
    parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('--loss_reduction_type', default='mean',
                        choices=['mean', 'sum'])
    # validation_split is consumed as a ratio by sess.train(); parse as float
    # so "0.2" on the command line does not arrive as a string.
    parser.add_argument('--validation_split',
                        default=0.0,
                        type=float,
                        help='validation split rate')
    return parser.parse_args(argv)
def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs):
    """
    Create an optimizer based on the specified type.

    Args:
        optimizer_type (str): The type of optimizer ('SGD' or 'Adam'),
            matched case-insensitively.
        learning_rate (float): The learning rate for the optimizer.
            NOTE(review): default of 0.001 is assumed — confirm; callers
            are expected to pass args.learning_rate explicitly.
        **kwargs: Additional parameters forwarded to the optimizer.

    Returns:
        Optimizer: The created optimizer instance.

    Raises:
        ValueError: If optimizer_type is neither 'sgd' nor 'adam'.
    """
    # Normalize once instead of calling .lower() per branch.
    opt_name = optimizer_type.lower()
    if opt_name == "sgd":
        return optimizer.SGD(learning_rate=learning_rate, **kwargs)
    elif opt_name == "adam":
        return optimizer.Adam(learning_rate=learning_rate, **kwargs)
    else:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
def train(args):
    """
    Main function to train the model.

    Args:
        args (argparse.Namespace): Parsed command line arguments; reads
            nnpkg, backends, input, label, data_length, batch_size,
            epoch, learning_rate, loss, optimizer, loss_reduction_type
            and validation_split.
    """
    # Create a training session for the nnpackage on the requested backends.
    sess = session(args.nnpkg, args.backends)

    # Model I/O shapes; index 0 assumes a single-input/single-output model
    # — TODO(review): confirm for multi-I/O nnpackages.
    input_shape = sess.input_tensorinfo(0).dims
    label_shape = sess.output_tensorinfo(0).dims

    # Replace the batch dimension with the total sample count so the
    # loader can validate the files on disk.
    input_shape[0] = args.data_length
    label_shape[0] = args.data_length

    data_loader = DataLoader(args.input,
                             args.label,
                             batch_size=args.batch_size,
                             input_shape=input_shape,
                             expected_shape=label_shape)

    # Build optimizer and loss from the CLI choices.
    opt_fn = createOptimizer(args.optimizer, args.learning_rate)
    loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type)

    # NOTE(review): compile/train keyword arguments were partially lost in
    # the original text — reconstructed from the onert experimental train
    # API; confirm metrics selection against the model's output.
    sess.compile(optimizer=opt_fn,
                 loss=loss_fn,
                 batch_size=args.batch_size,
                 metrics=[metrics.CategoricalAccuracy()])

    # Run training; returns a dict of phase timings in milliseconds.
    total_time = sess.train(data_loader,
                            epochs=args.epoch,
                            validation_split=args.validation_split,
                            checkpoint_path="checkpoint.ckpt")

    # Report the timing summary.
    print(f"MODEL_LOAD takes {total_time['MODEL_LOAD']:.4f} ms")
    print(f"COMPILE takes {total_time['COMPILE']:.4f} ms")
    print(f"EXECUTE takes {total_time['EXECUTE']:.4f} ms")
    epoch_times = total_time['EPOCH_TIMES']
    for i, epoch_time in enumerate(epoch_times):
        print(f"- Epoch {i + 1} takes {epoch_time:.4f} ms")

    print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.")