ONE - On-device Neural Engine
Loading...
Searching...
No Matches
test_luci_eval.py
Go to the documentation of this file.
1import numpy as np
2import subprocess
3import os
4
5
# Read the numbered reference files model_name.ref.input* / model_name.ref.output*
# and return their parsed contents.
def recover_fromfile(path, test_name, suffix):
    """Parse the numbered reference files ``<test_name>.ref.<suffix>N``.

    Files ``<path>/<test_name>.ref.<suffix>0``, ``...1``, ... are read in
    order until the first index for which no file exists. Each file is text:

    * line 1: comma-separated shape, e.g. "2,4"
    * line 2: dtype, e.g. "float32" (the only dtype accepted)
    * line 3: comma-separated values

    Returns:
        Tuple ``(count, shapes, dtypes, values)``: the number of files read,
        followed by three parallel lists holding, per file, the shape
        (list of int), the dtype string and the values (list of float).

    Raises:
        AssertionError: if a file has fewer than three lines, declares an
            unsupported dtype, or its value count does not match its shape.
    """
    ref_filename = test_name + ".ref." + suffix
    ref_datapath = os.path.join(path, ref_filename)

    shapes = []
    dtypes = []
    values = []
    count = 0

    while os.path.isfile(ref_datapath + str(count)):
        with open(ref_datapath + str(count), "r") as ref_file:
            content = ref_file.readlines()
            assert len(content) >= 3, "Invalid file: " + ref_filename + str(count)
            print("load reference data from", test_name)

            dims = list(map(int, content[0].split(",")))
            dtype = content[1].strip("\r\n \t")
            if dtype == "float32":
                data = list(map(float, content[2].split(",")))
            else:
                assert False, "Unsupported data type: " + dtype

            # The value count must equal the product of the dimensions
            # (an empty shape means a single scalar value).
            expected = 1
            for dim in dims:
                expected *= dim
            if expected != len(data):
                assert False, "Number of value elements do not match with shape"

            shapes.append(dims)
            dtypes.append(dtype)
            values.append(data)

        count += 1

    return count, shapes, dtypes, values
50
51
def recover_inputs(path, test_name):
    """Load the ``<test_name>.ref.input<N>`` reference files found in ``path``.

    Returns the ``(count, shapes, dtypes, values)`` tuple produced by
    ``recover_fromfile``.
    """
    return recover_fromfile(path, test_name, suffix="input")
54
55
def recover_outputs(path, test_name):
    """Load the ``<test_name>.ref.output<N>`` reference files found in ``path``.

    Returns the ``(count, shapes, dtypes, values)`` tuple produced by
    ``recover_fromfile``.
    """
    return recover_fromfile(path, test_name, suffix="output")
58
59
# Save reference data to input files for luci-eval-driver.
def save_binary_inputs(path, test_name, num_inputs, input_shape, input_dtype, input_data):
    """Write each reference input as the set of files used by luci-eval-driver.

    For every input index ``i`` three files are written under ``path``:

    * ``<test_name>.circle.input<i>``: raw values in native byte order
      (``ndarray.tofile`` binary mode)
    * ``<test_name>.circle.input<i>.shape``: comma-separated dimensions (text)
    * ``<test_name>.circle.input<i>.dtype``: the dtype name (text)

    Args:
        path: directory to write into.
        test_name: model base name.
        num_inputs: number of inputs to write.
        input_shape: per-input shape lists, as returned by ``recover_inputs``.
        input_dtype: per-input dtype strings.
        input_data: per-input value lists.

    Raises:
        AssertionError: if an input dtype is not "float32".
    """
    circle_inputpath = os.path.join(path, test_name + ".circle.input")
    for index in range(num_inputs):
        prefix = circle_inputpath + str(index)
        # reference input value
        if input_dtype[index] == "float32":
            nps = np.asarray(input_data[index], dtype=np.float32)
            nps.tofile(prefix)
        else:
            assert False, "Unsupported data type: " + input_dtype[index]
        # reference input shape, written as text. A 64-bit integer type is
        # used because int16 (np.short) cannot hold dimensions > 32767; the
        # text output is identical for every value int16 could represent.
        nps = np.asarray(input_shape[index], dtype=np.int64)
        nps.tofile(prefix + ".shape", sep=",")
        # reference input dtype
        with open(prefix + ".dtype", "w") as dtype_file:
            dtype_file.write(input_dtype[index])
76
77
def luci_eval_verify(test_name, binary_path, eval_driver, rtolf32=1e-5, atolf32=1e-5):
    """Run luci-eval-driver on ``<binary_path>/<test_name>.circle`` and check its outputs.

    Reference inputs are loaded from ``<test_name>.ref.input*`` and written
    out as driver input files. The driver is then executed, and each result
    file ``<model>.output<i>`` is compared against ``<test_name>.ref.output<i>``.

    Args:
        test_name: model base name.
        binary_path: directory holding the model and its reference files.
        eval_driver: path of the luci-eval-driver executable.
        rtolf32: relative tolerance for float32 outputs (``np.allclose``).
        atolf32: absolute tolerance for float32 outputs (``np.allclose``).

    Raises:
        AssertionError: if reference files are missing, a dtype is
            unsupported, an output has the wrong number of elements, or an
            output differs from its reference beyond the tolerances.
        subprocess.CalledProcessError: if the driver exits with non-zero status.
    """
    circle_model = os.path.join(binary_path, test_name + ".circle")

    num_inputs, input_shape, input_dtype, input_data = recover_inputs(
        binary_path, test_name)
    assert num_inputs > 0, "No valid reference input file"
    save_binary_inputs(binary_path, test_name, num_inputs, input_shape, input_dtype,
                       input_data)

    num_outputs, output_shape, output_dtype, output_data = recover_outputs(
        binary_path, test_name)
    assert num_outputs > 0, "No valid reference output file"

    # Execute luci interpreter.
    subprocess.run([
        eval_driver, circle_model,
        str(num_inputs), circle_model + ".input", circle_model + ".output"
    ],
                   check=True)

    # Compare the results.
    for idx in range(num_outputs):
        luci_output_data = np.fromfile(circle_model + ".output" + str(idx),
                                       output_dtype[idx])
        ref_output_data = np.reshape(output_data[idx], output_shape[idx])
        # Check the element count before reshaping so a wrong-sized driver
        # output produces a readable failure instead of a bare reshape error.
        assert luci_output_data.size == ref_output_data.size, (
            "Output " + str(idx) + " of " + test_name + " has " +
            str(luci_output_data.size) + " elements, expected " +
            str(ref_output_data.size))
        luci_output_data = np.reshape(luci_output_data, output_shape[idx])

        show_vals_and_stop = False
        if output_dtype[idx] == "float32":
            if not np.allclose(
                    luci_output_data, ref_output_data, rtol=rtolf32, atol=atolf32):
                show_vals_and_stop = True
        else:
            assert False, "Unsupported data type: " + output_dtype[idx]

        if show_vals_and_stop:
            print("\nreference:\n", ref_output_data)
            print("luci:\n", luci_output_data)
            message = "Execution result of " + test_name + " does not match with reference"
            assert False, message
118
119
# pytest fills these parameters by name; keep them in sync with `conftest.py`.
def test_luci_eval(default_test_name: str, binary_path: str, eval_driver_path: str):
    """Check a model against its reference data using the default float32 tolerances."""
    luci_eval_verify(test_name=default_test_name,
                     binary_path=binary_path,
                     eval_driver=eval_driver_path)
123
124
# pytest fills these parameters by name; keep them in sync with `conftest.py`.
def test_luci_eval_tol(tol_test_name: str, binary_path: str, eval_driver_path: str,
                       rtolf32: str, atolf32: str):
    """Check a model against its reference data using per-test float32 tolerances.

    The tolerances arrive as strings and are converted to ``float`` here.
    """
    rtol = float(rtolf32)
    atol = float(atolf32)
    luci_eval_verify(tol_test_name, binary_path, eval_driver_path,
                     rtolf32=rtol, atolf32=atol)
recover_fromfile(path, test_name, suffix)
save_binary_inputs(path, test_name, num_inputs, input_shape, input_dtype, input_data)
test_luci_eval_tol(str tol_test_name, str binary_path, str eval_driver_path, str rtolf32, str atolf32)
recover_inputs(path, test_name)
recover_outputs(path, test_name)
luci_eval_verify(test_name, binary_path, eval_driver, rtolf32=1e-5, atolf32=1e-5)