import argparse

from onert.experimental.train import session, DataLoader, optimizer, losses, metrics
# NOTE(review): the original `def` line was lost in the paste; the name
# `initParse` is reconstructed from the surrounding onert example style —
# confirm against the upstream file.
def initParse():
    """Build and parse the command-line arguments for the training run.

    Returns:
        argparse.Namespace: Parsed arguments (nnpkg, input, label,
        data_length, backends, batch_size, learning_rate, loss,
        optimizer, loss_reduction_type).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m',
                        '--nnpkg',
                        required=True,
                        help='Path to the nnpackage file or directory')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Path to the file containing input data (e.g., .npy or raw)')
    # NOTE(review): this add_argument line was dropped in the garbled paste;
    # reconstructed from its surviving help text — confirm flag spelling.
    parser.add_argument('-l',
                        '--label',
                        required=True,
                        help='Path to the file containing label data (e.g., .npy or raw).')
    parser.add_argument('--data_length', required=True, type=int, help='data length')
    parser.add_argument('--backends', default='train', help='Backends to use')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate')
    parser.add_argument('--loss', default='mse', choices=['mse', 'cce'])
    parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('--loss_reduction_type', default='mean', choices=['mean', 'sum'])
    return parser.parse_args()
# NOTE(review): the `def` line and the two `return` statements were lost in
# the paste; signature and returns reconstructed from the docstring and the
# `optimizer` import — confirm the learning_rate default against upstream.
def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs):
    """Create an optimizer based on the specified type.

    Args:
        optimizer_type (str): The type of optimizer ('SGD' or 'Adam').
            Matched case-insensitively.
        learning_rate (float): The learning rate for the optimizer.
        **kwargs: Additional parameters for the optimizer.

    Returns:
        Optimizer: The created optimizer instance.

    Raises:
        ValueError: If optimizer_type is neither 'sgd' nor 'adam'.
    """
    if optimizer_type.lower() == "sgd":
        return optimizer.SGD(learning_rate=learning_rate, **kwargs)
    elif optimizer_type.lower() == "adam":
        return optimizer.Adam(learning_rate=learning_rate, **kwargs)
    else:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
def main():
    """Main function to train the model.

    Parses CLI arguments, builds an onert training session, streams
    batches from the DataLoader, and prints per-step and averaged
    loss/metric/time statistics.
    """
    args = initParse()

    # Create a training session for the nnpackage on the requested backends.
    sess = session(args.nnpkg, args.backends)

    # Take the model's I/O shapes and widen the batch dimension to the full
    # dataset length so the loader knows how much raw data to read.
    input_shape = sess.input_tensorinfo(0).dims
    label_shape = sess.output_tensorinfo(0).dims
    input_shape[0] = args.data_length
    label_shape[0] = args.data_length

    # NOTE(review): the positional args between `args.input` and
    # `input_shape=` were lost in the paste; `args.label, args.batch_size`
    # is reconstructed from the CLI arguments — confirm against upstream.
    data_loader = DataLoader(args.input,
                             args.label,
                             args.batch_size,
                             input_shape=input_shape,
                             expected_shape=label_shape)

    # Build the optimizer and loss, then compile the session.
    # NOTE(review): the opt_fn line was missing from the paste; reconstructed
    # from the compile() call below. `createLoss` is defined elsewhere in
    # this file (its definition is outside this chunk).
    opt_fn = createOptimizer(args.optimizer, args.learning_rate)
    loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type)
    sess.compile(optimizer=opt_fn, loss=loss_fn, batch_size=args.batch_size)

    # Running aggregates over all training steps.
    # NOTE(review): `mtrs` and the zero-initializations were missing from the
    # paste; reconstructed so the accumulations below are well-defined —
    # confirm the metric list against upstream.
    mtrs = [metrics.CategoricalAccuracy()]
    total_loss = 0.0
    train_time = 0.0
    metric_aggregates = {metric.__class__.__name__: 0.0 for metric in mtrs}

    # Ceiling division: a final partial batch still counts as a step.
    nums_steps = (args.data_length + args.batch_size - 1) // args.batch_size
    for idx, (inputs, expecteds) in enumerate(data_loader):
        # Run a single training step and fold its results into the totals.
        results = sess.train_step(inputs, expecteds)
        total_loss += sum(results['loss'])

        for metric_name, metric_value in results['metrics'].items():
            metric_aggregates[metric_name] += metric_value

        train_time += results['train_time']

        print(
            f"Step {idx + 1}/{nums_steps} - Train time: {results['train_time']:.3f} ms/step - Train Loss: {sum(results['loss']):.4f}"
        )

    # NOTE(review): dividing by batch_size (not nums_steps) matches the
    # visible original, but looks like it may have been meant to average
    # over steps — confirm before changing.
    avg_metrics = {
        name: value / args.batch_size
        for name, value in metric_aggregates.items()
    }

    # Summary of the whole run.
    print(f"Average Loss: {total_loss / nums_steps:.4f}")
    for metric_name, metric_value in avg_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")
    print(f"Average Time: {train_time / nums_steps:.4f} ms/step")

    print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.")