ONE - On-device Neural Engine
Loading...
Searching...
No Matches
train_with_dataset.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3import argparse
4from onert.experimental.train import session, DataLoader, optimizer, losses, metrics
5
6
def initParse():
    """Build and parse the command-line arguments for the training script.

    Note: the ``def initParse():`` header was missing from the extracted
    source (dropped between source lines 6 and 8); it is restored here so the
    module is importable and the ``initParse()`` call in ``__main__`` resolves.

    Returns:
        argparse.Namespace: Parsed arguments (nnpkg, input, label,
        data_length, backends, batch_size, epoch, learning_rate, loss,
        optimizer, loss_reduction_type, validation_split).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m',
                        '--nnpkg',
                        required=True,
                        help='Path to the nnpackage file or directory')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Path to the file containing input data (e.g., .npy or raw)')
    parser.add_argument(
        '-l',
        '--label',
        required=True,
        help='Path to the file containing label data (e.g., .npy or raw).')
    parser.add_argument('--data_length', required=True, type=int, help='data length')
    parser.add_argument('--backends', default='train', help='Backends to use')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--epoch', default=5, type=int, help='epoch number')
    parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate')
    parser.add_argument('--loss', default='mse', choices=['mse', 'cce'])
    parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('--loss_reduction_type', default='mean', choices=['mean', 'sum'])
    parser.add_argument('--validation_split',
                        default=0.0,
                        type=float,
                        help='validation split rate')

    return parser.parse_args()
36
37
def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs):
    """Instantiate the optimizer named by *optimizer_type*.

    Args:
        optimizer_type (str): Optimizer name, 'SGD' or 'Adam' (case-insensitive).
        learning_rate (float): Learning rate handed to the optimizer.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Optimizer: The configured optimizer instance.

    Raises:
        ValueError: If *optimizer_type* names no known optimizer.
    """
    # Normalize once, then dispatch with guard-clause returns.
    kind = optimizer_type.lower()
    if kind == "sgd":
        return optimizer.SGD(learning_rate=learning_rate, **kwargs)
    if kind == "adam":
        return optimizer.Adam(learning_rate=learning_rate, **kwargs)
    raise ValueError(f"Unknown optimizer type: {optimizer_type}")
54
55
def createLoss(loss_type, reduction="mean"):
    """Instantiate the loss function named by *loss_type*.

    Args:
        loss_type (str): Loss name, 'mse' or 'cce' (case-insensitive).
        reduction (str): Reduction mode, 'mean' or 'sum'.

    Returns:
        object: The configured loss-function instance.

    Raises:
        ValueError: If *loss_type* names no known loss.
    """
    # Normalize once, then dispatch with guard-clause returns.
    kind = loss_type.lower()
    if kind == "mse":
        return losses.MeanSquaredError(reduction=reduction)
    if kind == "cce":
        return losses.CategoricalCrossentropy(reduction=reduction)
    raise ValueError(f"Unknown loss type: {loss_type}")
71
72
def train(args):
    """Load the nnpackage, train it on the given data, and print timings.

    Args:
        args (argparse.Namespace): Options produced by ``initParse()``.
    """
    # Open a training session on the requested backends.
    sess = session(args.nnpkg, args.backends)

    # Query model I/O shapes and pin the leading (batch) dimension to the
    # total number of samples the loader should read.
    in_shape = sess.input_tensorinfo(0).dims
    out_shape = sess.output_tensorinfo(0).dims
    in_shape[0] = args.data_length
    out_shape[0] = args.data_length

    data_loader = DataLoader(args.input,
                             args.label,
                             args.batch_size,
                             input_shape=in_shape,
                             expected_shape=out_shape)
    print('Load data')

    # Compile with the optimizer/loss selected on the command line; accuracy
    # is tracked as the single training metric.
    sess.compile(optimizer=createOptimizer(args.optimizer, args.learning_rate),
                 loss=createLoss(args.loss, reduction=args.loss_reduction_type),
                 batch_size=args.batch_size,
                 metrics=[metrics.CategoricalAccuracy()])

    # Run training; the session returns a per-phase timing dictionary.
    total_time = sess.train(data_loader,
                            epochs=args.epoch,
                            validation_split=args.validation_split,
                            checkpoint_path="checkpoint.ckpt")

    # Timing summary, framed by a separator line.
    separator = "=" * 35
    print(separator)
    for phase in ("MODEL_LOAD", "COMPILE", "EXECUTE"):
        print(f"{phase} takes {total_time[phase]:.4f} ms")
    for epoch_no, elapsed in enumerate(total_time['EPOCH_TIMES'], start=1):
        print(f"- Epoch {epoch_no} takes {elapsed:.4f} ms")
    print(separator)

    print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.")
122
123
if __name__ == "__main__":
    # Parse CLI options and kick off training.
    train(initParse())
createLoss(loss_type, reduction="mean")
createOptimizer(optimizer_type, learning_rate=0.001, **kwargs)