def initParse():
    """Build and run the command-line argument parser for the training script.

    Returns:
        argparse.Namespace: Parsed arguments. Required: the nnpackage path,
        input/label data files, and the dataset length. All training
        hyper-parameters have defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m',
                        '--nnpkg',
                        required=True,
                        help='Path to the nnpackage file or directory')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Path to the file containing input data (e.g., .npy or raw)')
    # NOTE(review): the '-l' option line was lost in the garbled source; it is
    # reconstructed from its surviving help string and the downstream use of
    # args.label — confirm flag spelling against upstream.
    parser.add_argument('-l',
                        '--label',
                        required=True,
                        help='Path to the file containing label data (e.g., .npy or raw).')
    parser.add_argument('--data_length', required=True, type=int, help='data length')
    parser.add_argument('--backends', default='train', help='Backends to use')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--epoch', default=5, type=int, help='epoch number')
    parser.add_argument('--learning_rate', default=0.01, type=float,
                        help='learning rate')
    # Loss/optimizer/reduction are restricted to the values the runtime supports.
    parser.add_argument('--loss', default='mse', choices=['mse', 'cce'])
    parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('--loss_reduction_type', default='mean',
                        choices=['mean', 'sum'])
    # NOTE(review): default/type for --validation_split were dropped by the
    # garbling; 0.0 (float) matches its direct use in sess.train() — confirm.
    parser.add_argument('--validation_split', default=0.0, type=float,
                        help='validation split rate')
    return parser.parse_args()
def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs):
    """Create an optimizer based on the specified type.

    Args:
        optimizer_type (str): The type of optimizer ('SGD' or 'Adam'),
            matched case-insensitively.
        learning_rate (float): The learning rate for the optimizer.
        **kwargs: Additional parameters for the optimizer.

    Returns:
        Optimizer: The created optimizer instance.

    Raises:
        ValueError: If ``optimizer_type`` is not a supported optimizer.
    """
    # Normalize once instead of lowering the string per comparison.
    kind = optimizer_type.lower()
    if kind == "sgd":
        return optimizer.SGD(learning_rate=learning_rate, **kwargs)
    if kind == "adam":
        return optimizer.Adam(learning_rate=learning_rate, **kwargs)
    raise ValueError(f"Unknown optimizer type: {optimizer_type}")
def train(args):
    """
    Main function to train the model.

    Args:
        args (argparse.Namespace): Parsed command-line arguments providing the
            nnpackage path, data files, and training hyper-parameters.
    """
    # Create a runtime session for the nnpackage on the requested backends.
    sess = session(args.nnpkg, args.backends)

    # Model I/O shapes; replace the batch dimension with the dataset length so
    # the data loader sees the full dataset extent.
    input_shape = sess.input_tensorinfo(0).dims
    label_shape = sess.output_tensorinfo(0).dims

    input_shape[0] = args.data_length
    label_shape[0] = args.data_length

    # NOTE(review): the DataLoader construction was garbled in the original;
    # positional arguments reconstructed from the surviving keyword arguments
    # and parser options — confirm against upstream.
    data_loader = DataLoader(args.input,
                             args.label,
                             args.batch_size,
                             input_shape=input_shape,
                             expected_shape=label_shape)

    # Build optimizer and loss from CLI selections.
    opt_fn = createOptimizer(args.optimizer, args.learning_rate)
    loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type)

    sess.compile(optimizer=opt_fn,
                 loss=loss_fn,
                 batch_size=args.batch_size,
                 metrics=[metrics.CategoricalAccuracy()])

    # Train; the runtime returns a timing summary (and writes a checkpoint).
    total_time = sess.train(data_loader,
                            epochs=args.epoch,
                            validation_split=args.validation_split,
                            checkpoint_path="checkpoint.ckpt")

    # Report the timing breakdown collected by the runtime.
    print(f"MODEL_LOAD takes {total_time['MODEL_LOAD']:.4f} ms")
    print(f"COMPILE takes {total_time['COMPILE']:.4f} ms")
    print(f"EXECUTE takes {total_time['EXECUTE']:.4f} ms")
    epoch_times = total_time['EPOCH_TIMES']
    for i, epoch_time in enumerate(epoch_times):
        print(f"- Epoch {i + 1} takes {epoch_time:.4f} ms")

    print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.")