ONE - On-device Neural Engine
Loading...
Searching...
No Matches
inference_benchmark.py
Go to the documentation of this file.
1import argparse
2import numpy as np
3import psutil
4import os
5from typing import List
6from onert import infer
7# TODO: re-export tensorinfo from the public onert API so this import of the native pybind module can be dropped
8from onert.native.libnnfw_api_pybind import tensorinfo
9
10
def get_memory_usage_mb() -> float:
    """Return this process's resident set size (RSS) in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 * 1024)
15
16
def parse_shapes(shape_strs: List[str]) -> List[List[int]]:
    """Parse comma-separated shape strings into lists of dimension ints.

    Args:
        shape_strs: One string per model input, e.g. ["1,224,224,3", "1,10"].

    Returns:
        A list of integer shape lists, one per input string.

    Raises:
        ValueError: If any string is not a comma-separated list of integers.
    """
    shapes: List[List[int]] = []
    for s in shape_strs:
        try:
            shapes.append([int(dim) for dim in s.strip().split(",")])
        except ValueError as err:
            # Catch only the conversion failure (a broad `except Exception`
            # would mask unrelated bugs) and chain the original cause.
            raise ValueError(
                f"Invalid shape string: '{s}' (expected: 1,224,224,3 ...)") from err
    return shapes
25
26
def get_validated_input_tensorinfos(sess: infer.session,
                                    static_shapes: List[List[int]]) -> List[tensorinfo]:
    """Apply user-supplied static shapes to the session's input tensorinfos.

    Validates that the shape count matches the model's input count and that
    each shape's rank matches the corresponding input's rank, then mutates
    each tensorinfo in place and returns the updated list.

    Raises:
        ValueError: On an input-count or per-input rank mismatch.
    """
    original_infos = sess.get_inputs_tensorinfo()
    if len(static_shapes) != len(original_infos):
        raise ValueError(
            f"Input count mismatch: model expects {len(original_infos)} inputs, but got {len(static_shapes)} shapes"
        )

    updated_infos: List[tensorinfo] = []
    for i, (info, shape) in enumerate(zip(original_infos, static_shapes)):
        if len(shape) != info.rank:
            raise ValueError(
                f"Rank mismatch for input {i}: expected rank {info.rank}, got {len(shape)}"
            )
        # Mutate the tensorinfo in place; callers receive the same objects.
        info.dims = shape
        info.rank = len(shape)
        updated_infos.append(info)

    return updated_infos
48
49
def benchmark_inference(nnpackage_path: str, backends: str, input_shapes: List[List[int]],
                        repeat: int):
    """Load an nnpackage, run warmup + measured inferences, print timing and RSS stats.

    Args:
        nnpackage_path: Path to the .nnpackage directory.
        backends: Backend selection string passed to onert (e.g. "cpu").
        input_shapes: Optional static shapes to force onto the model inputs;
            None keeps the model's own input tensorinfos.
        repeat: Number of measured inference runs (must be >= 1).
    """
    mem_before_kb = get_memory_usage_mb() * 1024

    sess = infer.session(path=nnpackage_path, backends=backends)
    model_load_kb = get_memory_usage_mb() * 1024 - mem_before_kb

    # BUGFIX: the head of this statement (`input_infos = get_validated_...(`)
    # was missing, leaving `input_infos` undefined; restored here.
    input_infos = get_validated_input_tensorinfos(
        sess, input_shapes) if input_shapes else sess.get_inputs_tensorinfo()

    # Create dummy input arrays matching each input's shape and dtype.
    dummy_inputs = []
    for info in input_infos:
        shape = tuple(info.dims[:info.rank])
        dummy_inputs.append(np.random.rand(*shape).astype(info.dtype))

    prepare = total_input = total_output = total_run = 0.0

    # Warmup runs: not measured, but prepare time/memory are captured here
    # (prepare happens lazily on the first infer call).
    prepare_kb = 0
    for _ in range(3):
        outputs, metrics = sess.infer(dummy_inputs, measure=True)
        del outputs  # release output buffers before the next memory snapshot
        if "prepare_time_ms" in metrics:
            prepare = metrics["prepare_time_ms"]
            prepare_kb = get_memory_usage_mb() * 1024 - mem_before_kb

    # Measured benchmark runs.
    for _ in range(repeat):
        outputs, metrics = sess.infer(dummy_inputs, measure=True)
        del outputs
        total_input += metrics["input_time_ms"]
        total_run += metrics["run_time_ms"]
        total_output += metrics["output_time_ms"]

    execute_kb = get_memory_usage_mb() * 1024 - mem_before_kb

    print("======= Inference Benchmark =======")
    print("- Warmup runs : 3")
    print(f"- Measured runs : {repeat}")
    print(f"- Prepare : {prepare:.3f} ms")
    print(f"- Avg I/O : {(total_input + total_output) / repeat:.3f} ms")
    print(f"- Avg Run : {total_run / repeat:.3f} ms")
    print("===================================")
    print("RSS")
    print(f"- MODEL_LOAD : {model_load_kb:.0f} KB")
    print(f"- PREPARE : {prepare_kb:.0f} KB")
    print(f"- EXECUTE : {execute_kb:.0f} KB")
    print(f"- PEAK : {max(model_load_kb, prepare_kb, execute_kb):.0f} KB")
    print("===================================")
100
101
def main():
    """CLI entry point: parse command-line arguments and run the benchmark."""
    parser = argparse.ArgumentParser(description="ONERT Inference Benchmark")
    parser.add_argument("nnpackage", type=str, help="Path to .nnpackage directory")
    parser.add_argument(
        "--backends", type=str, default="cpu", help="Backends to use (default: cpu)")
    parser.add_argument(
        "--input-shape",
        nargs="+",
        help="Input shapes for each input (e.g. 1,224,224,3 1,10)")
    parser.add_argument(
        "--repeat", type=int, default=5, help="Number of measured inference repetitions")
    args = parser.parse_args()

    # Only parse shapes when the user supplied --input-shape.
    shapes = None
    if args.input_shape:
        shapes = parse_shapes(args.input_shape)

    benchmark_inference(nnpackage_path=args.nnpackage,
                        backends=args.backends,
                        input_shapes=shapes,
                        repeat=args.repeat)


if __name__ == "__main__":
    main()
benchmark_inference(str nnpackage_path, str backends, List[List[int]] input_shapes, int repeat)
List[List[int]] parse_shapes(List[str] shape_strs)
List[tensorinfo] get_validated_input_tensorinfos(infer.session sess, List[List[int]] static_shapes)