import argparse

# NOTE: import path assumed for the onert Python training API.
from onert.experimental.train import session, DataLoader, optimizer, metrics


def initParse():  # function name assumed; the header is elided in this excerpt
    parser = argparse.ArgumentParser()
    # NOTE: option names on the elided lines ('--nnpkg', '--input', '-l/--label',
    # required=True) are reconstructed from how the arguments are used below.
    parser.add_argument('-m',
                        '--nnpkg',
                        required=True,
                        help='Path to the nnpackage file or directory')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Path to the file containing input data (e.g., .npy or raw)')
    parser.add_argument('-l',
                        '--label',
                        required=True,
                        help='Path to the file containing label data (e.g., .npy or raw)')
    parser.add_argument('--data_length', required=True, type=int, help='data length')
    parser.add_argument('--backends', default='train', help='Backends to use')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate')
    parser.add_argument('--loss', default='mse', choices=['mse', 'cce'])
    parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('--loss_reduction_type', default='mean', choices=['mean', 'sum'])
    return parser.parse_args()
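
# Example invocation (illustrative only; the script name and paths are
# hypothetical):
#   python3 train_with_dataset.py -m ./model.nnpkg -i ./input.npy -l ./label.npy \
#       --data_length 1000 --batch_size 16 --loss cce --optimizer adam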
def createOptimizer(optimizer_type, learning_rate, **kwargs):
    """
    Create an optimizer based on the specified type.

    Args:
        optimizer_type (str): The type of optimizer ('SGD' or 'Adam').
        learning_rate (float): The learning rate for the optimizer.
        **kwargs: Additional parameters for the optimizer.

    Returns:
        Optimizer: The created optimizer instance.
    """
    if optimizer_type.lower() == "sgd":
        return optimizer.SGD(learning_rate=learning_rate, **kwargs)
    elif optimizer_type.lower() == "adam":
        return optimizer.Adam(learning_rate=learning_rate, **kwargs)
    else:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
def main(args):
    """
    Main function to train the model.
    """
    # Create a training session bound to the requested backends.
    sess = session(args.nnpkg, args.backends)

    # Take the model's I/O shapes and fix the batch dimension to the dataset length.
    input_shape = sess.input_tensorinfo(0).dims
    label_shape = sess.output_tensorinfo(0).dims
    input_shape[0] = args.data_length
    label_shape[0] = args.data_length

    # Load input and label data. The DataLoader construction is elided in this
    # excerpt; the call below is reconstructed from the keyword arguments that
    # do appear and is an assumption.
    data_loader = DataLoader(args.input,
                             args.label,
                             batch_size=args.batch_size,
                             input_shape=input_shape,
                             expected_shape=label_shape)

    # Build the optimizer and loss (createLoss is a sibling helper not shown
    # in this excerpt), then compile the session.
    opt_fn = createOptimizer(args.optimizer, args.learning_rate)
    loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type)
    sess.compile(optimizer=opt_fn, loss=loss_fn, batch_size=args.batch_size)

    mtrs = [metrics.CategoricalAccuracy()]
    metric_aggregates = {metric.__class__.__name__: 0.0 for metric in mtrs}

    total_loss = 0.0
    train_time = 0.0
    nums_steps = (args.data_length + args.batch_size - 1) // args.batch_size
    for idx, (inputs, expecteds) in enumerate(data_loader):
        # Run a single training step and accumulate loss, metrics, and time.
        results = sess.train_step(inputs, expecteds)
        total_loss += sum(results['loss'])
        for metric_name, metric_value in results['metrics'].items():
            metric_aggregates[metric_name] += metric_value
        train_time += results['train_time']
        print(
            f"Step {idx + 1}/{nums_steps} - Train time: {results['train_time']:.3f} ms/step - Train Loss: {sum(results['loss']):.4f}"
        )

    avg_metrics = {
        name: value / args.batch_size
        for name, value in metric_aggregates.items()
    }

    print(f"Average Loss: {total_loss / nums_steps:.4f}")
    for metric_name, metric_value in avg_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")
    print(f"Average Time: {train_time / nums_steps:.4f} ms/step")

    print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.")