ONE - On-device Neural Engine
train_step_with_dataset.py
#!/usr/bin/env python3

import argparse
from onert.experimental.train import session, DataLoader, optimizer, losses, metrics


def initParse():
    """
    Parse the command-line arguments for this training script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m',
                        '--nnpkg',
                        required=True,
                        help='Path to the nnpackage file or directory')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Path to the file containing input data (e.g., .npy or raw)')
    parser.add_argument(
        '-l',
        '--label',
        required=True,
        help='Path to the file containing label data (e.g., .npy or raw).')
    parser.add_argument('--data_length', required=True, type=int, help='data length')
    parser.add_argument('--backends', default='train', help='Backends to use')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate')
    parser.add_argument('--loss', default='mse', choices=['mse', 'cce'])
    parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('--loss_reduction_type', default='mean', choices=['mean', 'sum'])

    return parser.parse_args()


def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs):
    """
    Create an optimizer based on the specified type.
    Args:
        optimizer_type (str): The type of optimizer ('SGD' or 'Adam').
        learning_rate (float): The learning rate for the optimizer.
        **kwargs: Additional parameters for the optimizer.
    Returns:
        Optimizer: The created optimizer instance.
    """
    if optimizer_type.lower() == "sgd":
        return optimizer.SGD(learning_rate=learning_rate, **kwargs)
    elif optimizer_type.lower() == "adam":
        return optimizer.Adam(learning_rate=learning_rate, **kwargs)
    else:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
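
# Example usage (a minimal sketch; the type string and learning rate mirror the
# CLI choices defined in initParse, and any extra keyword arguments are
# forwarded to the underlying onert optimizer unchanged):
#   opt = createOptimizer('adam', learning_rate=0.001)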


def createLoss(loss_type, reduction="mean"):
    """
    Create a loss function based on the specified type and reduction.
    Args:
        loss_type (str): The type of loss function ('mse', 'cce').
        reduction (str): Reduction type ('mean', 'sum').
    Returns:
        object: An instance of the specified loss function.
    """
    if loss_type.lower() == "mse":
        return losses.MeanSquaredError(reduction=reduction)
    elif loss_type.lower() == "cce":
        return losses.CategoricalCrossentropy(reduction=reduction)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
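
# Example usage (a minimal sketch; 'cce' and 'sum' mirror the CLI choices
# defined in initParse):
#   loss_fn = createLoss('cce', reduction='sum')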


def train_steps(args):
    """
    Main function to train the model.
    """
    # Create session and load nnpackage
    sess = session(args.nnpkg, args.backends)

    # Load data: take the model's I/O shapes and set the batch dimension
    # to the full data length
    input_shape = sess.input_tensorinfo(0).dims
    label_shape = sess.output_tensorinfo(0).dims

    input_shape[0] = args.data_length
    label_shape[0] = args.data_length

    data_loader = DataLoader(args.input,
                             args.label,
                             args.batch_size,
                             input_shape=input_shape,
                             expected_shape=label_shape)
    print('Data loaded')

    # Optimizer
    opt_fn = createOptimizer(args.optimizer, args.learning_rate)

    # Loss
    loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type)

    sess.compile(optimizer=opt_fn, loss=loss_fn, batch_size=args.batch_size)

    # Train model
    mtrs = [metrics.CategoricalAccuracy()]
    total_loss = 0.0
    metric_aggregates = {metric.__class__.__name__: 0.0 for metric in mtrs}
    train_time = 0.0

    # Number of steps per pass over the data (ceiling division)
    num_steps = (args.data_length + args.batch_size - 1) // args.batch_size
    for idx, (inputs, expecteds) in enumerate(data_loader):
        # Train on a single step
        results = sess.train_step(inputs, expecteds)
        total_loss += sum(results['loss'])

        # Aggregate metrics
        for metric_name, metric_value in results['metrics'].items():
            metric_aggregates[metric_name] += metric_value

        train_time += results['train_time']

        print(
            f"Step {idx + 1}/{num_steps} - Train time: {results['train_time']:.3f} ms/step - Train Loss: {sum(results['loss']):.4f}"
        )

    # Average the per-step metric values over the number of steps
    avg_metrics = {
        name: value / num_steps
        for name, value in metric_aggregates.items()
    }

    # Print results
    print("=" * 35)
    print(f"Average Loss: {total_loss / num_steps:.4f}")
    for metric_name, metric_value in avg_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")
    print(f"Average Time: {train_time / num_steps:.4f} ms/step")
    print("=" * 35)

    print(f"nnpackage {args.nnpkg.split('/')[-1]} trained successfully.")

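# train_steps() only reads attributes from its argument, so it can also be
# driven without the CLI; a minimal sketch with hypothetical placeholder paths:
#   from argparse import Namespace
#   train_steps(Namespace(nnpkg='./model.nnpkg', input='./x.npy', label='./y.npy',
#                         data_length=1000, backends='train', batch_size=16,
#                         learning_rate=0.01, loss='mse', optimizer='sgd',
#                         loss_reduction_type='mean'))
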
if __name__ == "__main__":
    args = initParse()

    train_steps(args)
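
# Example invocation (file paths are hypothetical placeholders; the flags match
# the parser defined in initParse):
#   python3 train_step_with_dataset.py --nnpkg ./model.nnpkg --input ./x.npy \
#       --label ./y.npy --data_length 1000 --batch_size 32 --loss cce --optimizer adam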