ONE - On-device Neural Engine
Loading...
Searching...
No Matches
test_luci_eval.py
Go to the documentation of this file.
1import numpy as np
2import tensorflow as tf
3import subprocess
4import os
5
6
def luci_eval_verify(test_name,
                     tflite_dir,
                     circle_dir,
                     eval_driver,
                     rtolf32=1e-5,
                     atolf32=1e-5):
    """Assert that luci-interpreter and the TFLite interpreter agree on a model.

    Runs ``<tflite_dir>/<test_name>.tflite`` in ``tf.lite.Interpreter`` and
    ``<circle_dir>/<test_name>.pass.circle`` through ``eval_driver`` on the same
    randomly generated inputs, then compares every output within tolerance.

    Args:
        test_name: Base name of the model files (without extension).
        tflite_dir: Directory containing ``<test_name>.tflite``.
        circle_dir: Directory containing ``<test_name>.pass.circle``. The input
            files (``.input<i>``) are written next to it, and the driver is told
            to write its outputs (``.output<i>`` and ``.output<i>.shape``) there.
        eval_driver: Path of the luci evaluation driver executable.
        rtolf32: Relative tolerance for float32 outputs.
        atolf32: Absolute tolerance for float32 outputs.

    Raises:
        AssertionError: On an unsupported input/output dtype or when an output
            does not match.
        subprocess.CalledProcessError: If ``eval_driver`` exits non-zero.
    """
    tflite_model = os.path.join(tflite_dir, test_name + ".tflite")
    circle_model = os.path.join(circle_dir, test_name + ".pass.circle")

    # NOTE reuse f32 value as int value too.
    # int() truncates, so the default 1e-5 becomes 0 (exact match for ints).
    rtolint = int(rtolf32)
    atolint = int(atolf32)

    # Build TFLite interpreter.
    interpreter = tf.lite.Interpreter(tflite_model)
    interpreter.allocate_tensors()

    # Output tensor ids in 'serving_default' signature order, or None.
    output_remap = _serving_default_output_remap(interpreter)

    # Generate random input data, feed it to TFLite and dump it for the driver.
    input_details_list = interpreter.get_input_details()
    num_inputs = len(input_details_list)
    for i, input_details in enumerate(input_details_list):
        input_data = _random_input(input_details["shape"], input_details["dtype"])
        interpreter.set_tensor(input_details["index"], input_data)
        input_data.tofile(circle_model + ".input" + str(i))

    # Do inference
    interpreter.invoke()

    # Execute luci interpreter.
    subprocess.run([
        eval_driver, circle_model,
        str(num_inputs), circle_model + ".input", circle_model + ".output"
    ], check=True)

    # Compare the results.
    err_msg = "Execution result of " + tflite_model + " does not match with " + circle_model
    for idx, output_details in enumerate(interpreter.get_output_details()):
        dtype = output_details["dtype"]
        luci_output_data = _read_luci_output(circle_model + ".output" + str(idx),
                                             dtype)

        output_tensor = output_details["index"]
        if output_remap is not None:
            output_tensor = output_remap[idx]
        intp_output_data = interpreter.get_tensor(output_tensor)

        if dtype == np.float32:
            rtol, atol = rtolf32, atolf32
        elif dtype in (np.uint8, np.int16, np.int32, np.int64):
            rtol, atol = rtolint, atolint
        elif dtype == np.bool_:
            rtol, atol = 0, 0
        else:
            # f-string: a numpy dtype cannot be '+'-concatenated to a str.
            assert False, f"Unsupported data type: {dtype}"

        assert np.allclose(luci_output_data,
                           intp_output_data,
                           rtol=rtol,
                           atol=atol), err_msg


def _serving_default_output_remap(interpreter):
    """Return output tensor ids in 'serving_default' signature order, or None.

    None is returned when the interpreter reports no signature list or no
    'serving_default' signature; the caller then falls back to the tensor
    indices from ``get_output_details()``.
    """
    full_signatures = interpreter._get_full_signature_list()
    if full_signatures is None:
        return None
    serving_default = full_signatures.get('serving_default', None)
    if serving_default is None:
        return None
    return list(serving_default['outputs'].values())


def _random_input(shape, dtype):
    """Return a random array of ``shape`` and ``dtype`` for use as a model input.

    float32 values are in [0, 1); uint8 values in [0, 256); int16/int32/int64
    values in [0, 100); bool values are drawn from {True, False}.
    """
    if dtype == np.float32:
        data = np.random.random_sample(shape)
    elif dtype == np.uint8:
        data = np.random.randint(0, 256, size=shape)
    elif dtype in (np.int16, np.int32, np.int64):
        data = np.random.randint(0, 100, size=shape)
    elif dtype == np.bool_:
        data = np.random.choice(a=[True, False], size=shape)
    else:
        assert False, f"Unsupported input dtype: {dtype}"
    return np.array(data, dtype)


def _read_luci_output(output_path, dtype):
    """Load one raw output written by the eval driver and reshape it.

    Raw data is read from ``output_path``; the shape is read from
    ``output_path + ".shape"`` as comma-separated dimensions.
    """
    output_data = np.fromfile(output_path, dtype)
    with open(output_path + ".shape", 'r') as shape_file:
        shape_text = shape_file.read()
    # Skip empty tokens so an empty shape file reshapes to a scalar instead of
    # failing on int(''). Presumably the driver writes an empty shape for
    # rank-0 outputs; confirm against the driver.
    output_shape = [int(dim) for dim in shape_text.split(',') if dim.strip()]
    return np.reshape(output_data, output_shape)
119
120
# arguments must be in sync with `conftest.py`
def test_luci_eval(test_name: str, tflite_dir: str, circle_dir: str,
                   eval_driver_path: str):
    """Pytest entry: check one model's luci-interpreter output against TFLite.

    The argument values come from the test configuration in ``conftest.py``;
    all the work is delegated to ``luci_eval_verify`` with default tolerances.
    """
    luci_eval_verify(test_name=test_name,
                     tflite_dir=tflite_dir,
                     circle_dir=circle_dir,
                     eval_driver=eval_driver_path)