# Command line:
#   --driver  executable that runs the partitioned model; it is invoked with
#             the .conn.ini file, the input count and the input/output path
#             prefixes.
#   --name    common path prefix of this test's model and config files.
parser = argparse.ArgumentParser()
parser.add_argument('--driver',
                    type=str,
                    required=True,
                    help='partition driver executable to run')
parser.add_argument('--name',
                    type=str,
                    required=True,
                    help='path prefix of the test model files')
args = parser.parse_args()

# `driver` is referenced when building the partition command, so bind it here.
driver = args.driver
tflite_model = args.name + ".tflite"
circle_model = args.name + ".circle"
partition_conn_ini = args.name + ".conn.ini"
partition_conn_json = args.name + ".conn.json"
expected_count = args.name + ".excnt"
# Optionally verify how many models the partitioner produced.
# The first line of the `.excnt` file holds the expected count (0 disables the
# check); the `.conn.json` file lists the produced models under "parts".
# Missing/unreadable/malformed files skip the check, while a genuine count
# mismatch fails the run instead of being silently ignored.
try:
    with open(expected_count, "r") as expected_count_file:
        expected_count_line = int(expected_count_file.readline())
    parts_value = None
    if expected_count_line:
        with open(partition_conn_json) as json_file:
            parts_value = json.load(json_file)["parts"]
except (OSError, ValueError, KeyError, TypeError):
    # Only file-access and parse errors land here (missing file, non-integer
    # count, invalid JSON, missing "parts"); exits raised below are not caught.
    print("Skip expected partitioned model count check: error")
else:
    if not expected_count_line:
        print("Skip expected partitioned model count check: 0")
    elif len(parts_value) != expected_count_line:
        print("Partitioned model count differs from expected:",
              expected_count_line)
        raise SystemExit("Found " + str(len(parts_value)) +
                         " partitioned models, expected " +
                         str(expected_count_line))
    else:
        print("Partitioned model count expected: ", expected_count_line)
# Build the TFLite reference interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()

# If the model has a 'serving_default' SignatureDef, collect its output tensor
# indices in signature order. Output N of the partitioned run is then compared
# against tensor full_signatures_outputs_remap[N] instead of the N-th entry of
# get_output_details(). Stays None when no such signature exists.
# NOTE(review): _get_full_signature_list() is a private TFLite API and may
# change between TensorFlow releases.
full_signatures = interpreter._get_full_signature_list()
full_signatures_outputs_remap = None
if full_signatures is not None:
    signature_serving_default = full_signatures.get('serving_default')
    if signature_serving_default is not None:
        full_signatures_outputs_remap = list(
            signature_serving_default['outputs'].values())
# Fill every interpreter input with random data matching its dtype and shape,
# and save a raw copy as <circle_model>.input<N> for the partition driver.
input_details_list = interpreter.get_input_details()  # fetched once
num_inputs = len(input_details_list)
for i, input_details in enumerate(input_details_list):
    input_details_dtype = input_details["dtype"]
    input_details_shape = input_details["shape"]
    if input_details_dtype == np.float32:
        input_data = np.array(np.random.random_sample(input_details_shape),
                              input_details_dtype)
    elif input_details_dtype == np.int16:
        input_data = np.array(
            np.random.randint(0, 100, size=input_details_shape),
            input_details_dtype)
    elif input_details_dtype == np.uint8:
        input_data = np.array(
            np.random.randint(0, 256, size=input_details_shape),
            input_details_dtype)
    elif input_details_dtype == np.bool_:
        input_data = np.array(
            np.random.choice(a=[True, False], size=input_details_shape),
            input_details_dtype)
    else:
        raise SystemExit("Unsupported input dtype")

    interpreter.set_tensor(input_details["index"], input_data)
    input_data.tofile(circle_model + ".input" + str(i))

# Run the reference inference: the output comparison reads the interpreter's
# output tensors, which are only computed by invoke().
interpreter.invoke()
# Run the partition driver on the same inputs. It reads
# <circle_model>.input<N> and writes <circle_model>.output<N> files.
partition_command = [
    driver, partition_conn_ini,
    str(num_inputs), circle_model + ".input", circle_model + ".output"
]

# Echo the command (one argument per line, shell-continuation style) and flush
# so it appears before the child's output even when stdout is buffered.
print("\n".join("  " + arg + " \\" for arg in partition_command), flush=True)

# Argument list with the default shell=False: no shell interpolation.
subprocess.run(partition_command, check=True)
# Compare every partitioned-model output with the TFLite reference output.
# For output N the driver leaves raw data in <circle_model>.output<N> and a
# comma-separated shape in <circle_model>.output<N>.shape.
inpt_output_details = interpreter.get_output_details()
for idx, output_details in enumerate(inpt_output_details):
    output_dtype = output_details["dtype"]
    output_path = circle_model + ".output" + str(idx)
    output_data = np.fromfile(output_path, output_dtype)
    with open(output_path + ".shape", 'r') as shape_file:
        # Skip empty tokens so an empty (scalar) shape or a trailing comma
        # does not crash int().
        output_shape = [
            int(dim) for dim in shape_file.read().split(',') if dim.strip()
        ]
    luci_output_data = np.reshape(output_data, output_shape)

    # Use the SignatureDef output order when one was found.
    output_tensor = output_details["index"]
    if full_signatures_outputs_remap is not None:
        output_tensor = full_signatures_outputs_remap[idx]
    intp_output_data = interpreter.get_tensor(output_tensor)

    # float32 gets a small tolerance; integer outputs must match exactly.
    if output_dtype == np.float32:
        rtol, atol = 1.e-5, 1.e-5
    elif output_dtype in (np.uint8, np.int16, np.int32, np.int64):
        rtol, atol = 0, 0
    else:
        raise SystemExit("Unsupported data type: " + str(output_dtype))

    # np.allclose broadcasts its operands, so compare shapes explicitly first;
    # otherwise e.g. a (1,) result could "match" a (4,) reference.
    if (luci_output_data.shape != intp_output_data.shape or not np.allclose(
            luci_output_data, intp_output_data, rtol=rtol, atol=atol)):
        raise SystemExit("Execution result of " + tflite_model +
                         " does not match with " + circle_model)