ONE - On-device Neural Engine
Loading...
Searching...
No Matches
test_luci_eval Namespace Reference

Functions

 luci_eval_verify (test_name, tflite_dir, circle_dir, eval_driver, rtolf32=1e-5, atolf32=1e-5)
 
 test_luci_eval (str test_name, str tflite_dir, str circle_dir, str eval_driver_path)
 
 recover_fromfile (path, test_name, suffix)
 
 recover_inputs (path, test_name)
 
 recover_outputs (path, test_name)
 
 save_binary_inputs (path, test_name, num_inputs, input_shape, input_dtype, input_data)
 
 luci_eval_verify (test_name, binary_path, eval_driver, rtolf32=1e-5, atolf32=1e-5)
 
 test_luci_eval (str default_test_name, str binary_path, str eval_driver_path)
 
 test_luci_eval_tol (str tol_test_name, str binary_path, str eval_driver_path, str rtolf32, str atolf32)
 

Function Documentation

◆ luci_eval_verify() [1/2]

test_luci_eval.luci_eval_verify (   test_name,
  binary_path,
  eval_driver,
  rtolf32 = 1e-5,
  atolf32 = 1e-5 
)

Definition at line 78 of file test_luci_eval.py.

def luci_eval_verify(test_name, binary_path, eval_driver, rtolf32=1e-5, atolf32=1e-5):
    """Run ``<test_name>.circle`` through the eval driver and check its outputs.

    Reference inputs and outputs are loaded with ``recover_inputs`` and
    ``recover_outputs`` from *binary_path*. The inputs are written next to
    the model by ``save_binary_inputs``, the driver is executed, and each
    ``<model>.output<N>`` file it produced is compared with the reference.

    Args:
        test_name: base name of the model and of its reference files.
        binary_path: directory holding the model and the reference files.
        eval_driver: executable invoked as
            ``eval_driver <model> <num_inputs> <model>.input <model>.output``.
        rtolf32: relative tolerance passed to ``np.allclose`` for float32.
        atolf32: absolute tolerance passed to ``np.allclose`` for float32.

    Raises:
        AssertionError: no reference input/output was found, an output dtype
            is not ``float32``, or a result differs from the reference.
        subprocess.CalledProcessError: the driver exited with a non-zero code.
    """
    circle_model = os.path.join(binary_path, test_name + ".circle")

    num_inputs, input_shape, input_dtype, input_data = recover_inputs(
        binary_path, test_name)
    # Failures are raised explicitly (not via `assert`) so that the checks
    # are still performed when Python runs with -O.
    if not num_inputs > 0:
        raise AssertionError("No valid reference input file")
    save_binary_inputs(binary_path, test_name, num_inputs, input_shape, input_dtype,
                       input_data)

    num_outputs, output_shape, output_dtype, output_data = recover_outputs(
        binary_path, test_name)
    if not num_outputs > 0:
        raise AssertionError("No valid reference output file")

    # Execute luci interpreter.
    subprocess.run([
        eval_driver, circle_model,
        str(num_inputs), circle_model + ".input", circle_model + ".output"
    ],
                   check=True)

    # Compare the results.
    for idx in range(num_outputs):
        luci_output_data = np.fromfile(circle_model + ".output" + str(idx),
                                       output_dtype[idx])
        luci_output_data = np.reshape(luci_output_data, output_shape[idx])
        ref_output_data = np.reshape(output_data[idx], output_shape[idx])

        if output_dtype[idx] != "float32":
            raise AssertionError("Unsupported data type: " + output_dtype[idx])

        if not np.allclose(
                luci_output_data, ref_output_data, rtol=rtolf32, atol=atolf32):
            # Show both tensors before failing to ease debugging.
            print("\nreference:\n", ref_output_data)
            print("luci:\n", luci_output_data)
            message = "Execution result of " + test_name + " does not match with reference"
            raise AssertionError(message)
118
119
120# arguments must be in sync with `conftest.py`

References recover_inputs(), recover_outputs(), and save_binary_inputs().

◆ luci_eval_verify() [2/2]

test_luci_eval.luci_eval_verify (   test_name,
  tflite_dir,
  circle_dir,
  eval_driver,
  rtolf32 = 1e-5,
  atolf32 = 1e-5 
)

Definition at line 7 of file test_luci_eval.py.

12 atolf32=1e-5):
13 tflite_model = os.path.join(tflite_dir, test_name + ".tflite")
14 circle_model = os.path.join(circle_dir, test_name + ".after.circle")
15
16 # NOTE reuse f32 value as int value too
17 rtolint = int(rtolf32)
18 atolint = int(atolf32)
19
20 # Build TFLite interpreter.
21 interpreter = tf.lite.Interpreter(tflite_model)
22 interpreter.allocate_tensors()
23
24 # Read SignatureDef and get output tensor id orders for remapping
25 full_signatures = interpreter._get_full_signature_list()
26 full_signatures_outputs_remap = None
27 if full_signatures != None:
28 signature_serving_default = full_signatures.get('serving_default', None)
29 if signature_serving_default != None:
30 signature_outputs = signature_serving_default['outputs']
31
32 full_signatures_outputs_remap = []
33 for index, (key, value) in enumerate(signature_outputs.items()):
34 full_signatures_outputs_remap.append(value)
35
36 # Generate random input data.
37 num_inputs = len(interpreter.get_input_details())
38 for i in range(num_inputs):
39 input_details = interpreter.get_input_details()[i]
40 if input_details["dtype"] == np.float32:
41 input_data = np.array(np.random.random_sample(input_details["shape"]),
42 input_details["dtype"])
43 elif input_details["dtype"] == np.uint8:
44 input_data = np.array(np.random.randint(0, 256, size=input_details["shape"]),
45 input_details["dtype"])
46 elif input_details["dtype"] == np.int16:
47 input_data = np.array(np.random.randint(0, 100, size=input_details["shape"]),
48 input_details["dtype"])
49 elif input_details["dtype"] == np.int32:
50 input_data = np.array(np.random.randint(0, 100, size=input_details["shape"]),
51 input_details["dtype"])
52 elif input_details["dtype"] == np.int64:
53 input_data = np.array(np.random.randint(0, 100, size=input_details["shape"]),
54 input_details["dtype"])
55 elif input_details["dtype"] == np.bool_:
56 input_data = np.array(
57 np.random.choice(a=[True, False], size=input_details["shape"]),
58 input_details["dtype"])
59 else:
60 assert False, "Unsupported input dtype"
61
62 interpreter.set_tensor(input_details["index"], input_data)
63 input_data.tofile(circle_model + ".input" + str(i))
64
65 # Do inference
66 interpreter.invoke()
67
68 # Execute luci interpreter.
69 subprocess.run([
70 eval_driver, circle_model,
71 str(num_inputs), circle_model + ".input", circle_model + ".output"
72 ],
73 check=True)
74
75 # Compare the results.
76 inpt_output_details = interpreter.get_output_details()
77 for idx in range(len(inpt_output_details)):
78 output_details = inpt_output_details[idx]
79 output_data = np.fromfile(circle_model + ".output" + str(idx),
80 output_details["dtype"])
81 shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r')
82 output_shape = [int(i) for i in shape_file.read().split(',')]
83 luci_output_data = np.reshape(output_data, output_shape)
84 output_tensor = output_details["index"]
85 if full_signatures_outputs_remap != None:
86 output_tensor = full_signatures_outputs_remap[idx]
87 intp_output_data = interpreter.get_tensor(output_tensor)
88 err_msg = "Execution result of " + tflite_model + " does not match with " + circle_model
89 if output_details["dtype"] == np.uint8:
90 assert np.allclose(luci_output_data,
91 intp_output_data,
92 rtol=rtolint,
93 atol=atolint), err_msg
94 elif output_details["dtype"] == np.float32:
95 assert np.allclose(luci_output_data,
96 intp_output_data,
97 rtol=rtolf32,
98 atol=atolf32), err_msg
99 elif output_details["dtype"] == np.int64:
100 assert np.allclose(luci_output_data,
101 intp_output_data,
102 rtol=rtolint,
103 atol=atolint), err_msg
104 elif output_details["dtype"] == np.int32:
105 assert np.allclose(luci_output_data,
106 intp_output_data,
107 rtol=rtolint,
108 atol=atolint), err_msg
109 elif output_details["dtype"] == np.int16:
110 assert np.allclose(luci_output_data,
111 intp_output_data,
112 rtol=rtolint,
113 atol=atolint), err_msg
114 elif output_details["dtype"] == np.bool_:
115 assert np.allclose(luci_output_data, intp_output_data, rtol=0,
116 atol=0), err_msg
117 else:
118 assert False, "Unsupported data type: " + output_details["dtype"]
119
120
121# arguments must be in sync with `conftest.py`

Referenced by test_luci_eval(), test_luci_eval(), and test_luci_eval_tol().

◆ recover_fromfile()

test_luci_eval.recover_fromfile (   path,
  test_name,
  suffix 
)

Definition at line 8 of file test_luci_eval.py.

def recover_fromfile(path, test_name, suffix):
    """Load every reference file ``<test_name>.ref.<suffix><N>`` found in *path*.

    Files are read for N = 0, 1, 2, ... and reading stops at the first
    index for which no file exists.

    .ref file format:
      1st line is shape, i.e. "2,4"
      2nd line is dtype, i.e. "float32"
      3rd line is comma separated values

    Returns:
        A tuple ``(num_data, shapes, dtypes, values)``; each list has one
        entry per file that was read.

    Raises:
        AssertionError: a file has fewer than three lines, declares a dtype
            other than ``float32``, or its number of values does not match
            the number of elements implied by its shape.
    """
    ref_filename = test_name + ".ref." + suffix
    ref_datapath = os.path.join(path, ref_filename)

    num_data = 0
    parse_shape = []
    parse_dtype = []
    parse_value = []

    while True:
        refnum_filepath = ref_datapath + str(num_data)
        if not os.path.isfile(refnum_filepath):
            break
        with open(refnum_filepath, "r") as ref_file:
            lines = ref_file.readlines()
        # Raise explicitly instead of using `assert` so that validation is
        # not stripped when Python runs with -O.
        if len(lines) < 3:
            raise AssertionError("Invalid file: " + ref_filename + str(num_data))
        print("load reference data from", test_name)
        shape = [int(i) for i in lines[0].split(",")]
        dtype = lines[1].strip("\r\n \t")
        if dtype != "float32":
            raise AssertionError("Unsupported data type: " + dtype)
        value = [float(i) for i in lines[2].split(",")]

        # validate shape and number of elements
        num_elements = 1
        for dim in shape:
            num_elements = num_elements * dim
        if num_elements != len(value):
            raise AssertionError("Number of value elements do not match with shape")

        parse_shape.append(shape)
        parse_dtype.append(dtype)
        parse_value.append(value)

        num_data = num_data + 1

    return num_data, parse_shape, parse_dtype, parse_value
50
51

Referenced by recover_inputs(), and recover_outputs().

◆ recover_inputs()

test_luci_eval.recover_inputs (   path,
  test_name 
)

Definition at line 52 of file test_luci_eval.py.

def recover_inputs(path, test_name):
    """Load the reference input files of *test_name* located in *path*.

    Delegates to ``recover_fromfile`` with the ``"input"`` suffix and
    returns its result unchanged.
    """
    input_suffix = "input"
    return recover_fromfile(path, test_name, input_suffix)
54
55

References recover_fromfile().

Referenced by luci_eval_verify().

◆ recover_outputs()

test_luci_eval.recover_outputs (   path,
  test_name 
)

Definition at line 56 of file test_luci_eval.py.

def recover_outputs(path, test_name):
    """Load the reference output files of *test_name* located in *path*.

    Delegates to ``recover_fromfile`` with the ``"output"`` suffix and
    returns its result unchanged.
    """
    output_suffix = "output"
    return recover_fromfile(path, test_name, output_suffix)
58
59
60# save reference data to input files for luci-eval-driver

References recover_fromfile().

Referenced by luci_eval_verify().

◆ save_binary_inputs()

test_luci_eval.save_binary_inputs (   path,
  test_name,
  num_inputs,
  input_shape,
  input_dtype,
  input_data 
)

Definition at line 61 of file test_luci_eval.py.

def save_binary_inputs(path, test_name, num_inputs, input_shape, input_dtype, input_data):
    """Write reference inputs as ``<test_name>.circle.input<N>`` files in *path*.

    For every input index N three files are produced:

    * ``...input<N>``        raw float32 values (binary)
    * ``...input<N>.shape``  comma separated dimensions (text)
    * ``...input<N>.dtype``  dtype name (text)

    Args:
        path: directory the files are written to.
        test_name: base name used to build the file names.
        num_inputs: number of inputs to write.
        input_shape: per-input list of dimensions.
        input_dtype: per-input dtype name; only ``"float32"`` is supported.
        input_data: per-input flat sequence of values.

    Raises:
        AssertionError: an input declares a dtype other than ``float32``.
    """
    circle_inputpath = os.path.join(path, test_name + ".circle.input")
    for index in range(num_inputs):
        # reference input value
        if input_dtype[index] != "float32":
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError("Unsupported data type: " + input_dtype[index])
        nps = np.asarray(input_data[index], dtype=np.float32)
        nps.tofile(circle_inputpath + str(index))
        # reference input shape
        # int64 instead of int16: a dimension above 32767 must not overflow.
        # The file is written as text, so the on-disk format is unchanged.
        nps = np.asarray(input_shape[index], dtype=np.int64)
        nps.tofile(circle_inputpath + str(index) + ".shape", sep=",")
        # reference input dtype
        with open(circle_inputpath + str(index) + ".dtype", "w") as dtype_file:
            dtype_file.write(input_dtype[index])
76
77

Referenced by luci_eval_verify().

◆ test_luci_eval() [1/2]

test_luci_eval.test_luci_eval ( str  default_test_name,
str  binary_path,
str  eval_driver_path 
)

Definition at line 121 of file test_luci_eval.py.

def test_luci_eval(default_test_name: str, binary_path: str, eval_driver_path: str):
    """Verify one test case using the default tolerances of ``luci_eval_verify``."""
    luci_eval_verify(
        default_test_name,
        binary_path,
        eval_driver_path,
    )
123
124
125# arguments must be in sync with `conftest.py`

References luci_eval_verify().

◆ test_luci_eval() [2/2]

test_luci_eval.test_luci_eval ( str  test_name,
str  tflite_dir,
str  circle_dir,
str  eval_driver_path 
)

Definition at line 122 of file test_luci_eval.py.

123 eval_driver_path: str):
124 luci_eval_verify(test_name, tflite_dir, circle_dir, eval_driver_path)

References luci_eval_verify().

◆ test_luci_eval_tol()

test_luci_eval.test_luci_eval_tol ( str  tol_test_name,
str  binary_path,
str  eval_driver_path,
str  rtolf32,
str  atolf32 
)

Definition at line 126 of file test_luci_eval.py.

127 rtolf32: str, atolf32: str):
128 luci_eval_verify(tol_test_name, binary_path, eval_driver_path, float(rtolf32),
129 float(atolf32))

References luci_eval_verify().