# luci_eval_verify: run a model through the TFLite interpreter (the "<test_name>.tflite"
# file) and through an external eval driver (the "<test_name>.pass.circle" file),
# then assert that every output matches within per-dtype tolerances.
#
# NOTE(review): this listing is damaged. Original line numbers are fused into the
# text (e.g. "7def", "13 tflite_model"), and several source lines are missing:
#   - the rest of the parameter list;
#   - an `else:` before each "Unsupported ..." assert;
#   - the start/end of the eval_driver invocation around "70 eval_driver, ...";
#   - the middle arguments (intp_output_data / rtol) of most np.allclose calls.
# Restore from version control before changing any logic.
#
# Parameters visible in this listing:
#   - test_name: base name used to build both model paths;
#   - eval_driver: path of the driver executable (4th positional arg, per test_luci_eval).
# tflite_dir, circle_dir, rtolf32 and atolf32 are used in the body and presumably
# come from the missing signature lines. TODO confirm their defaults.
7def luci_eval_verify(test_name,
13 tflite_model = os.path.join(tflite_dir, test_name +
".tflite")
14 circle_model = os.path.join(circle_dir, test_name +
".pass.circle")
# Integer tolerances are derived from the float32 ones by truncation with int().
# NOTE(review): any float tolerance below 1.0 becomes 0 here, i.e. integer outputs
# must match exactly. Confirm this is intended.
17 rtolint = int(rtolf32)
18 atolint = int(atolf32)
# Reference run: the TFLite interpreter on the .tflite model.
21 interpreter = tf.lite.Interpreter(tflite_model)
22 interpreter.allocate_tensors()
# If the model has a 'serving_default' signature, record the values of its
# 'outputs' mapping in order. These values are later passed to
# interpreter.get_tensor() in place of the output_details index, so they presumably
# are tensor indices. Verify against the TFLite signature format.
# NOTE(review): `!= None` should be `is not None`. `index` and `key` are unused in
# the loop below.
25 full_signatures = interpreter._get_full_signature_list()
26 full_signatures_outputs_remap =
None
27 if full_signatures !=
None:
28 signature_serving_default = full_signatures.get(
'serving_default',
None)
29 if signature_serving_default !=
None:
30 signature_outputs = signature_serving_default[
'outputs']
32 full_signatures_outputs_remap = []
33 for index, (key, value)
in enumerate(signature_outputs.items()):
34 full_signatures_outputs_remap.append(value)
# For each model input:
#   - generate random data by dtype:
#       float32: random_sample in [0, 1);
#       uint8:   randint in [0, 256);
#       int16/int32/int64: randint in [0, 100);
#       bool:    random True/False;
#   - feed it to the TFLite interpreter;
#   - write it to "<circle_model>.input<i>" for the eval driver.
# NOTE(review): interpreter.get_input_details() is re-fetched on every iteration;
# it could be hoisted out of the loop.
37 num_inputs = len(interpreter.get_input_details())
38 for i
in range(num_inputs):
39 input_details = interpreter.get_input_details()[i]
40 if input_details[
"dtype"] == np.float32:
41 input_data = np.array(np.random.random_sample(input_details[
"shape"]),
42 input_details[
"dtype"])
43 elif input_details[
"dtype"] == np.uint8:
44 input_data = np.array(np.random.randint(0, 256, size=input_details[
"shape"]),
45 input_details[
"dtype"])
46 elif input_details[
"dtype"] == np.int16:
47 input_data = np.array(np.random.randint(0, 100, size=input_details[
"shape"]),
48 input_details[
"dtype"])
49 elif input_details[
"dtype"] == np.int32:
50 input_data = np.array(np.random.randint(0, 100, size=input_details[
"shape"]),
51 input_details[
"dtype"])
52 elif input_details[
"dtype"] == np.int64:
53 input_data = np.array(np.random.randint(0, 100, size=input_details[
"shape"]),
54 input_details[
"dtype"])
55 elif input_details[
"dtype"] == np.bool_:
56 input_data = np.array(
57 np.random.choice(a=[
True,
False], size=input_details[
"shape"]),
58 input_details[
"dtype"])
# NOTE(review): an `else:` line appears to be missing before this assert.
60 assert False,
"Unsupported input dtype"
62 interpreter.set_tensor(input_details[
"index"], input_data)
63 input_data.tofile(circle_model +
".input" + str(i))
# NOTE(review): only the argument list of the eval driver call survives below. The
# call itself and any interpreter.invoke() for the TFLite side are not visible in
# this listing; presumably they sit in the missing lines. Confirm.
# The driver gets the circle model, the input count, the input file prefix and the
# output file prefix.
70 eval_driver, circle_model,
71 str(num_inputs), circle_model +
".input", circle_model +
".output"
# Compare outputs. For each output idx:
#   - read the driver's raw data from "<circle_model>.output<idx>";
#   - read its comma-separated shape from "<circle_model>.output<idx>.shape";
#   - reshape the data and compare it with the TFLite tensor, optionally remapped
#     via the serving_default signature.
# NOTE(review): shape_file is opened but never closed; wrap it in `with open(...)`.
# An empty shape file would make int('') raise ValueError. Check how the driver
# writes scalar shapes.
76 inpt_output_details = interpreter.get_output_details()
77 for idx
in range(len(inpt_output_details)):
78 output_details = inpt_output_details[idx]
79 output_data = np.fromfile(circle_model +
".output" + str(idx),
80 output_details[
"dtype"])
81 shape_file = open(circle_model +
".output" + str(idx) +
".shape",
'r')
82 output_shape = [int(i)
for i
in shape_file.read().split(
',')]
83 luci_output_data = np.reshape(output_data, output_shape)
84 output_tensor = output_details[
"index"]
85 if full_signatures_outputs_remap !=
None:
86 output_tensor = full_signatures_outputs_remap[idx]
87 intp_output_data = interpreter.get_tensor(output_tensor)
88 err_msg =
"Execution result of " + tflite_model +
" does not match with " + circle_model
# Per-dtype comparison:
#   - integer dtypes use atolint;
#   - float32 uses atolf32;
#   - bool uses rtol=0.
# The remaining np.allclose arguments are missing from this listing.
89 if output_details[
"dtype"] == np.uint8:
90 assert np.allclose(luci_output_data,
93 atol=atolint), err_msg
94 elif output_details[
"dtype"] == np.float32:
95 assert np.allclose(luci_output_data,
98 atol=atolf32), err_msg
99 elif output_details[
"dtype"] == np.int64:
100 assert np.allclose(luci_output_data,
103 atol=atolint), err_msg
104 elif output_details[
"dtype"] == np.int32:
105 assert np.allclose(luci_output_data,
108 atol=atolint), err_msg
109 elif output_details[
"dtype"] == np.int16:
110 assert np.allclose(luci_output_data,
113 atol=atolint), err_msg
114 elif output_details[
"dtype"] == np.bool_:
115 assert np.allclose(luci_output_data, intp_output_data, rtol=0,
# NOTE(review): an `else:` line appears to be missing before this assert.
# NOTE(review): the message concatenates a str with output_details["dtype"]. If that
# is a numpy type object (as the `== np.uint8` comparisons suggest), this raises
# TypeError instead of reporting the message. Use str(...) or an f-string.
118 assert False,
"Unsupported data type: " + output_details[
"dtype"]
def test_luci_eval(test_name: str, tflite_dir: str, circle_dir: str,
                   eval_driver_path: str):
    """Check that luci evaluation of one model matches the TFLite reference.

    Delegates to ``luci_eval_verify``, passing positionally:
      - the test name;
      - the directory holding ``<test_name>.tflite``;
      - the directory holding ``<test_name>.pass.circle``;
      - the path of the eval driver executable.

    No tolerance arguments are passed, so ``luci_eval_verify``'s own defaults apply.
    """
    # NOTE(review): these arguments are presumably supplied by pytest
    # parametrization/fixtures (e.g. in a conftest.py); keep the names in sync there.
    luci_eval_verify(test_name, tflite_dir, circle_dir, eval_driver_path)