def luci_eval_verify(test_name, artifacts, eval_driver, rtolf32=1e-5, atolf32=1e-5):
    """Check that the circle model evaluated by ``eval_driver`` matches TFLite.

    Loads ``<artifacts>/<test_name>.tflite`` into the TFLite interpreter and
    feeds it random inputs. For every input N it writes the raw data, shape and
    dtype name next to the circle model as ``<circle>.inputN``,
    ``<circle>.inputN.shape`` and ``<circle>.inputN.dtype``. It then runs
    ``eval_driver <circle> <num_inputs> <circle>.input <circle>.output``.

    For every output N it reads ``<circle>.outputN`` and
    ``<circle>.outputN.shape`` back and compares them with the TFLite result.
    After each successful comparison it writes ``<circle>.outputN.dtype``.

    Args:
        test_name: Model base name (without extension) inside ``artifacts``.
        artifacts: Directory containing ``<test_name>.tflite`` and
            ``<test_name>.circle``.
        eval_driver: Executable that evaluates the circle model.
        rtolf32: Relative tolerance for float32 outputs. Its ``int()``
            truncation (0 for the default) is reused for integer outputs.
        atolf32: Absolute tolerance for float32 outputs, reused the same way.

    Raises:
        AssertionError: If an input/output dtype is unsupported or an output
            does not match the TFLite result.
        subprocess.CalledProcessError: If ``eval_driver`` exits non-zero.
    """
    # Function-scope import so this block does not rely on a file-level import.
    import subprocess

    tflite_model = os.path.join(artifacts, test_name + ".tflite")
    circle_model = os.path.join(artifacts, test_name + ".circle")

    # Build the TFLite reference interpreter.
    interpreter = tf.lite.Interpreter(tflite_model)
    interpreter.allocate_tensors()

    # With a 'serving_default' signature, its output order decides which TFLite
    # tensor corresponds to the N-th output file written by the eval driver.
    output_remap = _luci_output_remap(interpreter)

    # Feed random inputs to TFLite and dump identical inputs for the driver.
    input_details_list = interpreter.get_input_details()
    num_inputs = len(input_details_list)
    for i, input_details in enumerate(input_details_list):
        dtype_name = np.dtype(input_details["dtype"]).name
        input_data = _luci_random_input(input_details["shape"], dtype_name)
        if input_data is None:
            raise AssertionError("Unsupported input dtype")
        interpreter.set_tensor(input_details["index"], input_data)

        input_prefix = circle_model + ".input" + str(i)
        input_data.tofile(input_prefix)
        input_details["shape"].tofile(input_prefix + ".shape", sep=',')
        with open(input_prefix + ".dtype", 'w') as dtype_file:
            dtype_file.write(dtype_name)

    # Run the reference model so get_tensor() below returns computed results.
    interpreter.invoke()

    # Evaluate the circle model. check=True makes a driver failure fail the
    # test immediately instead of comparing stale output files.
    subprocess.run(
        [
            eval_driver, circle_model,
            str(num_inputs), circle_model + ".input", circle_model + ".output"
        ],
        check=True)

    err_msg = f"Execution result of {tflite_model} does not match with {circle_model}"
    for idx, output_details in enumerate(interpreter.get_output_details()):
        dtype_name = np.dtype(output_details["dtype"]).name
        output_prefix = circle_model + ".output" + str(idx)

        output_data = np.fromfile(output_prefix, output_details["dtype"])
        with open(output_prefix + ".shape", 'r') as shape_file:
            # Skip empty fields: an empty shape file (scalar output) or a
            # trailing separator would otherwise make int('') raise.
            output_shape = [int(dim) for dim in shape_file.read().split(',') if dim.strip()]
        luci_output_data = np.reshape(output_data, output_shape)

        output_tensor = output_details["index"]
        if output_remap is not None:
            output_tensor = output_remap[idx]
        intp_output_data = interpreter.get_tensor(output_tensor)

        tolerance = _luci_output_tolerance(dtype_name, rtolf32, atolf32)
        if tolerance is None:
            # Format via the dtype name; str + numpy type raised TypeError here.
            raise AssertionError("Unsupported data type: " + dtype_name)
        rtol, atol = tolerance
        assert np.allclose(luci_output_data, intp_output_data, rtol=rtol, atol=atol), err_msg

        with open(output_prefix + ".dtype", 'w') as dtype_file:
            dtype_file.write(dtype_name)


def _luci_output_remap(interpreter):
    """Return output tensor indices in 'serving_default' signature order, or None.

    Relies on the private TFLite ``_get_full_signature_list`` API. None means no
    'serving_default' signature is available, so outputs are matched in
    ``get_output_details()`` order instead.
    """
    full_signatures = interpreter._get_full_signature_list()
    if full_signatures is None:
        return None
    signature_serving_default = full_signatures.get('serving_default')
    if signature_serving_default is None:
        return None
    # Only the values (tensor indices) are used; their order defines the remap.
    return list(signature_serving_default['outputs'].values())


def _luci_random_input(shape, dtype_name):
    """Return a random array of ``dtype_name`` with ``shape``; None if unsupported."""
    if dtype_name == "float32":
        data = np.random.random_sample(shape)
    elif dtype_name == "uint8":
        data = np.random.randint(0, 256, size=shape)
    elif dtype_name in ("int16", "int32", "int64"):
        data = np.random.randint(0, 100, size=shape)
    elif dtype_name == "bool":
        data = np.random.choice(a=[True, False], size=shape)
    else:
        return None
    return np.array(data, dtype_name)


def _luci_output_tolerance(dtype_name, rtolf32, atolf32):
    """Return ``(rtol, atol)`` for comparing an output of ``dtype_name``, or None."""
    if dtype_name == "float32":
        return rtolf32, atolf32
    if dtype_name in ("uint8", "int16", "int32", "int64"):
        # NOTE the float tolerances are reused, truncated to int, for integers.
        return int(rtolf32), int(atolf32)
    if dtype_name == "bool":
        # Boolean outputs must match exactly.
        return 0, 0
    return None
def test_luci_eval(default_test_name: str, artifacts_path: str, eval_driver_path: str):
    """Verify the default test model against the TFLite reference.

    Forwards its arguments to ``luci_eval_verify`` and leaves the float32
    tolerances at their defaults. The parameter names look like pytest fixture
    names (presumably supplied by a conftest.py — confirm), so they must not be
    renamed.
    """
    luci_eval_verify(
        test_name=default_test_name,
        artifacts=artifacts_path,
        eval_driver=eval_driver_path,
    )
142def test_luci_eval_tol(tol_test_name: str, artifacts_path: str, eval_driver_path: str,
143 rtolf32: str, atolf32: str):
144 luci_eval_verify(tol_test_name, artifacts_path, eval_driver_path, float(rtolf32),