        static_shapes: List[List[int]]) -> List[tensorinfo]:
    original_infos = sess.get_inputs_tensorinfo()
    if len(static_shapes) != len(original_infos):
        raise ValueError(
            f"Input count mismatch: model expects {len(original_infos)} inputs, "
            f"but got {len(static_shapes)} shapes")

    updated_infos: List[tensorinfo] = []
    for i, info in enumerate(original_infos):
        shape = static_shapes[i]
        if info.rank != len(shape):
            raise ValueError(
                f"Rank mismatch for input {i}: expected rank {info.rank}, got {len(shape)}")
        # Overwrite the declared (possibly dynamic) shape with the requested static one.
        info.dims = shape
        info.rank = len(shape)
        updated_infos.append(info)
    return updated_infos
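
# Illustration (hypothetical numbers): for a model whose first input is declared
# rank-4 with a dynamic batch, e.g. dims [-1, 224, 224, 3], passing
# static_shapes=[[224, 224, 3]] trips the rank check above (3 vs. 4), while
# supplying two shapes for a one-input model trips the count check.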

sess = infer.session(path=nnpackage_path, backends=backends)
# Resolve `input_infos`: use the static shapes when provided, otherwise the
# tensorinfo the model itself declares.
    sess, input_shapes) if input_shapes else sess.get_inputs_tensorinfo()

dummy_inputs = []
for info in input_infos:
    shape = tuple(info.dims[:info.rank])
    dummy_inputs.append(np.random.rand(*shape).astype(info.dtype))
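
# Note: np.random.rand() yields floats in [0, 1), so casting to an integer
# dtype gives all zeros. A minimal, hedged sketch of a dtype-aware generator
# (pure NumPy; `_random_tensor` is an illustrative helper, not part of the
# runtime API) that also covers integer and boolean inputs:
def _random_tensor(shape, dtype):
    dt = np.dtype(dtype)
    if dt.kind in "iu":        # signed / unsigned integers
        return np.random.randint(0, 10, size=shape, dtype=dt)
    if dt.kind == "b":         # booleans
        return np.random.rand(*shape) > 0.5
    return np.random.rand(*shape).astype(dt)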

prepare = total_input = total_output = total_run = 0.0

# Warm-up run; capture the one-time prepare cost if the runtime reports it.
outputs, metrics = sess.infer(dummy_inputs, measure=True)

if "prepare_time_ms" in metrics:
    prepare = metrics["prepare_time_ms"]

for _ in range(repeat):
    outputs, metrics = sess.infer(dummy_inputs, measure=True)
    total_input += metrics["input_time_ms"]
    total_run += metrics["run_time_ms"]
    total_output += metrics["output_time_ms"]

print("======= Inference Benchmark =======")
print("- Warmup runs   : 3")
print(f"- Measured runs : {repeat}")
print(f"- Prepare       : {prepare:.3f} ms")
print(f"- Avg I/O       : {(total_input + total_output) / repeat:.3f} ms")
print(f"- Avg Run       : {total_run / repeat:.3f} ms")
print("===================================")
# Memory usage per phase (KB).
print(f"- MODEL_LOAD : {model_load_kb:.0f} KB")
print(f"- PREPARE    : {prepare_kb:.0f} KB")
print(f"- EXECUTE    : {execute_kb:.0f} KB")
print(f"- PEAK       : {max(model_load_kb, prepare_kb, execute_kb):.0f} KB")
print("===================================")