def _check_partition_count(expected_count_path, partition_conn_json_path):
    """Check the number of partitioned models against the expected count.

    The expected count is the integer on the first line of
    ``expected_count_path``; the actual count is the length of the ``"parts"``
    entry of the JSON document in ``partition_conn_json_path``.

    The check is skipped (with a message) when the expected count is 0, or
    when either file cannot be read or parsed.  A readable count that does
    not match fails with ``AssertionError``.
    """
    try:
        with open(expected_count_path, "r") as expected_count_file:
            expected_count = int(expected_count_file.readline())
        if not expected_count:
            print("Skip expected partitioned model count check: 0")
            return
        with open(partition_conn_json_path) as json_file:
            parts_count = len(json.load(json_file)["parts"])
    except (OSError, ValueError, KeyError, TypeError):
        # Missing or malformed files only disable the check.  The comparison
        # below is deliberately outside this handler so a mismatch can never
        # be swallowed and turned into a "skip".
        print("Skip expected partitioned model count check: error")
        return

    assert parts_count == expected_count, (
        "Partitioned model count differs from expected: {} (actual: {})".format(
            expected_count, parts_count))
    print("Partitioned model count expected: ", expected_count)


def _signature_output_remap(interpreter):
    """Return output tensor indices in ``serving_default`` signature order.

    Returns ``None`` when the model exposes no signature list or has no
    ``serving_default`` signature; the caller then reads outputs in
    ``get_output_details()`` order.
    """
    # NOTE: relies on the private TFLite API ``_get_full_signature_list``.
    full_signatures = interpreter._get_full_signature_list()
    if full_signatures is None:
        return None
    serving_default = full_signatures.get('serving_default')
    if serving_default is None:
        return None
    # Values of the 'outputs' mapping are used directly as tensor indices.
    return list(serving_default['outputs'].values())


def _random_input(input_details):
    """Return random data matching the dtype and shape of one model input.

    ``input_details`` is one entry of ``Interpreter.get_input_details()``.
    Raises ``AssertionError`` for an unsupported dtype.
    """
    dtype = input_details["dtype"]
    shape = input_details["shape"]
    if dtype == np.float32:
        values = np.random.random_sample(shape)
    elif dtype == np.uint8:
        values = np.random.randint(0, 256, size=shape)
    elif dtype in (np.int16, np.int32, np.int64):
        # All three integer widths draw from the same range, 0..99.
        values = np.random.randint(0, 100, size=shape)
    elif dtype == np.bool_:
        values = np.random.choice(a=[True, False], size=shape)
    else:
        raise AssertionError("Unsupported input dtype: " + str(dtype))
    return np.array(values, dtype)


def _output_tolerance(dtype):
    """Return ``(rtol, atol)`` for comparing outputs of element type *dtype*.

    float32 outputs may differ slightly; every other supported type must
    match exactly.  Raises ``AssertionError`` for an unsupported type.
    """
    if dtype == np.float32:
        return 1.e-5, 1.e-5
    if dtype in (np.uint8, np.int64, np.int32, np.int16, np.bool_):
        return 0, 0
    # str() is needed: dtype is not a str, and concatenating it directly
    # would raise TypeError instead of reporting the unsupported type.
    raise AssertionError("Unsupported data type: " + str(dtype))


def part_eval(test_name, bin_dir, circle_part_driver):
    """Compare TFLite interpreter results with the partitioned circle model.

    Artifacts are read from ``<bin_dir>/<test_name>/``.  ``<test_name>.tflite``
    is run by the TFLite interpreter on random inputs; the same inputs are
    written to ``<test_name>.circle.input<N>`` and fed to
    ``circle_part_driver`` together with ``<test_name>.conn.ini``.  Each
    driver output ``<test_name>.circle.output<N>`` (with its ``.shape`` file)
    must match the corresponding interpreter output.

    Args:
        test_name: Test name; also the artifact directory name and file stem.
        bin_dir: Directory containing the per-test artifact directories.
        circle_part_driver: Path to the circle-part-driver executable.

    Raises:
        AssertionError: On a partition-count mismatch, an unsupported tensor
            dtype, or an output mismatch.
        subprocess.CalledProcessError: If the driver exits with non-zero status.
    """
    artifacts_dir = os.path.join(bin_dir, test_name)
    tflite_model = os.path.join(artifacts_dir, test_name + ".tflite")
    circle_model = os.path.join(artifacts_dir, test_name + ".circle")
    partition_conn_ini = os.path.join(artifacts_dir, test_name + ".conn.ini")
    partition_conn_json = os.path.join(artifacts_dir, test_name + ".conn.json")
    expected_count = os.path.join(artifacts_dir, test_name + ".excnt")

    _check_partition_count(expected_count, partition_conn_json)

    interpreter = tf.lite.Interpreter(tflite_model)
    interpreter.allocate_tensors()
    outputs_remap = _signature_output_remap(interpreter)

    # Feed random data to the interpreter and dump each input for the driver.
    input_details_list = interpreter.get_input_details()
    num_inputs = len(input_details_list)
    for i, input_details in enumerate(input_details_list):
        input_data = _random_input(input_details)
        interpreter.set_tensor(input_details["index"], input_data)
        input_data.tofile(circle_model + ".input" + str(i))

    # Output tensors hold results only after inference has actually run.
    interpreter.invoke()

    partition_command = [
        circle_part_driver, partition_conn_ini,
        str(num_inputs), circle_model + ".input", circle_model + ".output"
    ]
    print("Run: ")
    for arg in partition_command:
        print("    ", arg, "\\")
    print("", flush=True)
    # NOTE(review): the driver runs from artifacts_dir, presumably because the
    # .conn.ini names model files relative to it — confirm.
    subprocess.run(partition_command, check=True, cwd=artifacts_dir)

    for idx, output_details in enumerate(interpreter.get_output_details()):
        output_path = circle_model + ".output" + str(idx)
        output_data = np.fromfile(output_path, output_details["dtype"])
        with open(output_path + ".shape", 'r') as shape_file:
            output_shape = [int(dim) for dim in shape_file.read().split(',')]
        luci_output_data = np.reshape(output_data, output_shape)

        # Driver outputs follow signature order when a signature is present.
        output_tensor = output_details["index"]
        if outputs_remap is not None:
            output_tensor = outputs_remap[idx]
        intp_output_data = interpreter.get_tensor(output_tensor)

        rtol, atol = _output_tolerance(output_details["dtype"])
        assert np.allclose(
            luci_output_data, intp_output_data, rtol=rtol, atol=atol
        ), "Execution result of " + tflite_model + " does not match with " + circle_model