# ONE - On-device Neural Engine
# train_step_with_dataset.py
# Sample: train an nnpackage step by step with onert's experimental training
# API, reading input and label data from files given on the command line.
import argparse
import os

from onert.experimental.train import session, DataLoader, optimizer, losses, metrics
def initParse(argv=None):
    """
    Parse the command-line options for the training sample.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is how the
            script entry point calls this function.
    Returns:
        argparse.Namespace: Parsed options with attributes ``nnpkg``, ``input``,
        ``label``, ``data_length``, ``backends``, ``batch_size``,
        ``learning_rate``, ``loss``, ``optimizer`` and ``loss_reduction_type``.
    Raises:
        SystemExit: Raised by argparse (via ``parser.error``) when a required
        option is missing, a choice is invalid, or ``--data_length`` /
        ``--batch_size`` is not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m',
                        '--nnpkg',
                        required=True,
                        help='Path to the nnpackage file or directory')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Path to the file containing input data (e.g., .npy or raw)')
    parser.add_argument(
        '-l',
        '--label',
        required=True,
        help='Path to the file containing label data (e.g., .npy or raw).')
    parser.add_argument('--data_length', required=True, type=int, help='data length')
    parser.add_argument('--backends', default='train', help='Backends to use')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate')
    parser.add_argument('--loss', default='mse', choices=['mse', 'cce'])
    parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('--loss_reduction_type', default='mean', choices=['mean', 'sum'])

    args = parser.parse_args(argv)

    # Both values are used as divisors / step counts by the training loop, so
    # reject non-positive values up front with a normal argparse usage error.
    if args.data_length <= 0:
        parser.error("--data_length must be a positive integer")
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")

    return args
29
30
def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs):
    """
    Instantiate the optimizer named by ``optimizer_type``.

    Args:
        optimizer_type (str): Case-insensitive optimizer name, 'sgd' or 'adam'.
        learning_rate (float): Learning rate passed to the optimizer constructor.
        **kwargs: Extra keyword arguments forwarded unchanged to the constructor.
    Returns:
        Optimizer: An ``optimizer.SGD`` or ``optimizer.Adam`` instance.
    Raises:
        ValueError: If ``optimizer_type`` is not one of the supported names.
    """
    kind = optimizer_type.lower()
    if kind == "sgd":
        factory = optimizer.SGD
    elif kind == "adam":
        factory = optimizer.Adam
    else:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
    # Both optimizers are built with the same constructor call shape.
    return factory(learning_rate=learning_rate, **kwargs)
47
48
def createLoss(loss_type, reduction="mean"):
    """
    Build the loss object selected by ``loss_type``.

    Args:
        loss_type (str): Case-insensitive loss name: 'mse' or 'cce'.
        reduction (str): Reduction mode forwarded to the loss ('mean', 'sum').
    Returns:
        object: A ``losses.MeanSquaredError`` or
        ``losses.CategoricalCrossentropy`` instance built with ``reduction``.
    Raises:
        ValueError: If ``loss_type`` names neither supported loss.
    """
    # Lowercase CLI names -> class names exposed by the ``losses`` module.
    class_name_by_key = {
        "mse": "MeanSquaredError",
        "cce": "CategoricalCrossentropy",
    }
    class_name = class_name_by_key.get(loss_type.lower())
    if class_name is None:
        raise ValueError(f"Unknown loss type: {loss_type}")
    loss_cls = getattr(losses, class_name)
    return loss_cls(reduction=reduction)
64
65
def train_steps(args):
    """
    Train the nnpackage given on the command line and print a summary.

    Loads the model into an onert training session and streams input/label
    batches from ``args.input`` / ``args.label`` through a ``DataLoader``.
    It runs one ``sess.train_step`` per batch and prints per-step loss and
    time, followed by the average loss, metric values and step time.

    Args:
        args (argparse.Namespace): Parsed options as produced by ``initParse``.
            Reads ``nnpkg``, ``backends``, ``input``, ``label``,
            ``data_length``, ``batch_size``, ``optimizer``, ``learning_rate``,
            ``loss`` and ``loss_reduction_type``.
    """
    # Create session and load nnpackage
    sess = session(args.nnpkg, args.backends)

    # Copy the model's first input/output shapes before overriding the leading
    # dimension, so the session's own tensorinfo objects are never mutated
    # (and so this also works if ``dims`` is an immutable sequence).
    input_shape = list(sess.input_tensorinfo(0).dims)
    label_shape = list(sess.output_tensorinfo(0).dims)

    # NOTE(review): dim 0 is assumed to be the sample axis; it is replaced by
    # the total sample count so DataLoader can shape the whole data file.
    input_shape[0] = args.data_length
    label_shape[0] = args.data_length

    data_loader = DataLoader(args.input,
                             args.label,
                             args.batch_size,
                             input_shape=input_shape,
                             expected_shape=label_shape)
    print('Load data')

    # optimizer
    opt_fn = createOptimizer(args.optimizer, args.learning_rate)

    # loss
    loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type)

    sess.compile(optimizer=opt_fn, loss=loss_fn, batch_size=args.batch_size)

    # Train model
    total_loss = 0.0
    # Metric names are taken from each step's results, so no metric list has
    # to be declared up front (the previous code referenced an undefined one).
    metric_aggregates = {}
    train_time = 0.0
    steps_done = 0

    # Expected number of steps (ceil(data_length / batch_size)); used only for
    # the "Step i/N" progress display.
    nums_steps = (args.data_length + args.batch_size - 1) // args.batch_size
    for idx, (inputs, expecteds) in enumerate(data_loader):
        # Train on a single step
        results = sess.train_step(inputs, expecteds)
        step_loss = sum(results['loss'])
        total_loss += step_loss

        # Aggregate metrics (one value per metric per step)
        for metric_name, metric_value in results['metrics'].items():
            metric_aggregates[metric_name] = metric_aggregates.get(metric_name, 0.0) + metric_value

        train_time += results['train_time']
        steps_done += 1

        print(
            f"Step {idx + 1}/{nums_steps} - Train time: {results['train_time']:.3f} ms/step - Train Loss: {step_loss:.4f}"
        )

    # Average every per-step quantity over the steps actually executed
    # (metrics were previously divided by batch_size, which is not a step
    # count). max(..., 1) avoids ZeroDivisionError for an empty loader.
    denom = max(steps_done, 1)
    avg_metrics = {name: value / denom for name, value in metric_aggregates.items()}

    # Print results
    print("=" * 35)
    print(f"Average Loss: {total_loss / denom:.4f}")
    for metric_name, metric_value in avg_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")
    print(f"Average Time: {train_time / denom:.4f} ms/step")
    print("=" * 35)

    # normpath drops a trailing separator so a directory path like "pkg/"
    # still reports "pkg" instead of an empty name.
    print(f"nnpackage {os.path.basename(os.path.normpath(args.nnpkg))} trains successfully.")
130
131 print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.")
132
133
if __name__ == "__main__":
    # Script entry point: read the command-line options, then hand them to
    # the training loop.
    args = initParse()
    train_steps(args)
# Helper functions documented for this file:
#   createOptimizer(optimizer_type, learning_rate=0.001, **kwargs)
#   createLoss(loss_type, reduction="mean")