ONE - On-device Neural Engine
Loading...
Searching...
No Matches
test_circle_part_value.py
Go to the documentation of this file.
1import numpy as np
2import tensorflow as tf
3import subprocess
4import os
5import json
6
7
8# Compares the execution result of TFLite interpreter and partitioned model(s) from a circle model.
def part_eval(test_name, bin_dir, circle_part_driver):
    """Compare TFLite interpreter outputs with those of a partitioned circle model.

    The artifacts for the test are looked up in ``<bin_dir>/<test_name>/`` using
    ``<test_name>`` plus the suffixes ``.tflite``, ``.circle``, ``.conn.ini``,
    ``.conn.json`` and ``.excnt``.

    Steps:
      1. Optionally verify the number of partitioned models (``.excnt`` vs the
         ``"parts"`` list of ``.conn.json``).
      2. Feed random inputs to the TFLite interpreter and dump the same inputs to
         ``<circle_model>.input<N>`` files.
      3. Run ``circle_part_driver`` with the artifact folder as working directory.
      4. Read ``<circle_model>.output<N>`` / ``.output<N>.shape`` (expected to be
         written by the driver) and compare them with the interpreter outputs.

    Args:
        test_name: Base name of the model artifacts and of their folder.
        bin_dir: Directory that contains the ``<test_name>`` artifact folder.
        circle_part_driver: Path of the driver executable to run.

    Raises:
        AssertionError: If the partition count differs from the expected one, an
            input/output dtype is unsupported, or any output does not match.
        subprocess.CalledProcessError: If the driver exits with a non-zero status.
    """
    artifacts_dir = os.path.join(bin_dir, test_name)
    tflite_model = os.path.join(artifacts_dir, test_name + ".tflite")
    circle_model = os.path.join(artifacts_dir, test_name + ".circle")
    partition_conn_ini = os.path.join(artifacts_dir, test_name + ".conn.ini")
    partition_conn_json = os.path.join(artifacts_dir, test_name + ".conn.json")
    expected_count = os.path.join(artifacts_dir, test_name + ".excnt")

    # Check expected count of models from partitioning
    _check_partition_count(expected_count, partition_conn_json)

    # Build TFLite interpreter.
    interpreter = tf.lite.Interpreter(tflite_model)
    interpreter.allocate_tensors()

    # Read SignatureDef and get output tensor id orders for remapping
    full_signatures = interpreter._get_full_signature_list()
    full_signatures_outputs_remap = None
    if full_signatures is not None:
        signature_serving_default = full_signatures.get('serving_default', None)
        if signature_serving_default is not None:
            full_signatures_outputs_remap = list(
                signature_serving_default['outputs'].values())

    # Generate random input data.
    all_input_details = interpreter.get_input_details()
    num_inputs = len(all_input_details)
    for i, input_details in enumerate(all_input_details):
        input_data = _random_input(input_details["shape"], input_details["dtype"])
        interpreter.set_tensor(input_details["index"], input_data)
        # The driver reads the very same input values from these files.
        input_data.tofile(circle_model + ".input" + str(i))

    # Do inference
    interpreter.invoke()

    # Execute circle-part-driver.
    partition_command = [
        circle_part_driver, partition_conn_ini,
        str(num_inputs), circle_model + ".input", circle_model + ".output"
    ]
    print("Run: ")
    for arg in partition_command:
        print("    ", arg, "\\")
    print("", flush=True)

    # working directory into the folder as ini has relative filename of the model
    subprocess.run(partition_command, check=True, cwd=artifacts_dir)

    # Compare the results.
    for idx, output_details in enumerate(interpreter.get_output_details()):
        output_file = circle_model + ".output" + str(idx)
        output_data = np.fromfile(output_file, output_details["dtype"])
        # The shape file holds comma separated dimensions; close it right after reading.
        with open(output_file + ".shape", 'r') as shape_file:
            output_shape = [int(dim) for dim in shape_file.read().split(',')]
        luci_output_data = np.reshape(output_data, output_shape)

        output_tensor = output_details["index"]
        if full_signatures_outputs_remap is not None:
            # Use the tensor id recorded in the 'serving_default' signature instead.
            output_tensor = full_signatures_outputs_remap[idx]
        intp_output_data = interpreter.get_tensor(output_tensor)

        _assert_outputs_match(luci_output_data, intp_output_data,
                              output_details["dtype"], tflite_model, circle_model)


def _check_partition_count(expected_count_path, partition_conn_json):
    """Fail if the number of partitioned models differs from the expected count.

    The check is best-effort: it is skipped when the expected count is 0 or when
    either file is missing or cannot be parsed. A real mismatch, however, raises
    AssertionError and is never swallowed by the skip handling.
    """
    try:
        with open(expected_count_path, "r") as expected_count_file:
            expected = int(expected_count_file.readline())
        if not expected:
            print("Skip expected partitioned model count check: 0")
            return
        with open(partition_conn_json) as json_file:
            actual = len(json.load(json_file)["parts"])
    except (OSError, ValueError, KeyError, TypeError) as err:
        # Only I/O and parse problems mean "nothing to check".
        print("Skip expected partitioned model count check: error", err)
        return

    # Raised outside the try block so the mismatch cannot be mistaken for a skip.
    if actual != expected:
        print("Partitioned model count differs from expected:", expected)
        raise AssertionError("Partitioned model count " + str(actual) +
                             " differs from expected " + str(expected))

    print("Partitioned model count expected: ", expected)


def _random_input(shape, dtype):
    """Return a random array of the given shape, converted to the given dtype."""
    if dtype == np.float32:
        values = np.random.random_sample(shape)
    elif dtype == np.uint8:
        values = np.random.randint(0, 256, size=shape)
    elif dtype in (np.int16, np.int32, np.int64):
        values = np.random.randint(0, 100, size=shape)
    elif dtype == np.bool_:
        values = np.random.choice(a=[True, False], size=shape)
    else:
        raise AssertionError("Unsupported input dtype")
    return np.array(values, dtype)


def _assert_outputs_match(luci_output_data, intp_output_data, dtype, tflite_model,
                          circle_model):
    """Raise AssertionError unless both outputs agree within the dtype's tolerance."""
    if dtype == np.float32:
        tolerance = 1.e-5
    elif dtype in (np.uint8, np.int16, np.int32, np.int64, np.bool_):
        # Integer and boolean results must be identical.
        tolerance = 0
    else:
        # str() is required: the dtype is not a string and cannot be concatenated.
        raise AssertionError("Unsupported data type: " + str(dtype))

    if not np.allclose(luci_output_data, intp_output_data, rtol=tolerance,
                       atol=tolerance):
        raise AssertionError("Execution result of " + tflite_model +
                             " does not match with " + circle_model)
139
140
141# arguments must be in sync with `conftest.py`
def test_circle_part_value(test_name: str, bin_dir: str, part_driver_path: str):
    """Pytest entry point: evaluate one partitioned model test case.

    Forwards the three arguments to ``part_eval``; ``part_driver_path`` is
    passed as its ``circle_part_driver`` parameter.
    """
    part_eval(
        test_name=test_name,
        bin_dir=bin_dir,
        circle_part_driver=part_driver_path,
    )
part_eval(test_name, bin_dir, circle_part_driver)