# Command-line interface: --driver, --tflite and --circle are all mandatory
# string options.
parser = argparse.ArgumentParser()
for option in ('--driver', '--tflite', '--circle'):
    parser.add_argument(option, type=str, required=True)
args = parser.parse_args()

# tflite_model is handed to the TFLite interpreter; circle_model is used as the
# prefix of the ".input<N>" / ".output<N>" file names further down.
tflite_model, circle_model = args.tflite, args.circle
# Load the TFLite model into the TensorFlow Lite interpreter and allocate its
# tensors.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()

# Collect the values of the 'serving_default' signature's 'outputs' mapping, in
# the order the mapping yields them. Further down they replace
# output_details["index"] as the argument to interpreter.get_tensor().
# The remap stays None when there is no signature list or no 'serving_default'
# entry, which tells the comparison code to fall back to output_details["index"].
# NOTE(review): _get_full_signature_list() is a private Interpreter method —
# confirm it is available in the TensorFlow version in use.
full_signatures = interpreter._get_full_signature_list()
full_signatures_outputs_remap = None
if full_signatures is not None:
    signature_serving_default = full_signatures.get('serving_default')
    if signature_serving_default is not None:
        signature_outputs = signature_serving_default['outputs']
        full_signatures_outputs_remap = list(signature_outputs.values())
# Feed every model input with random data and save the same data to disk.
# Input i is written raw (ndarray.tofile) to "<circle_model>.input<i>".
# The input details are queried once up front rather than on every iteration.
all_input_details = interpreter.get_input_details()
num_inputs = len(all_input_details)
for i, input_details in enumerate(all_input_details):
    input_dtype = input_details["dtype"]
    input_shape = input_details["shape"]

    # Pick a random generator per dtype; the result is cast to the input's
    # dtype below.
    if input_dtype == np.float32:
        random_values = np.random.random_sample(input_shape)
    elif input_dtype == np.uint8:
        # randint's upper bound is exclusive, so values are 0..255.
        random_values = np.random.randint(0, 256, size=input_shape)
    elif input_dtype == np.int16:
        random_values = np.random.randint(0, 100, size=input_shape)
    elif input_dtype == np.bool_:
        random_values = np.random.choice(a=[True, False], size=input_shape)
    else:
        raise SystemExit("Unsupported input dtype")

    input_data = np.array(random_values, input_dtype)
    interpreter.set_tensor(input_details["index"], input_data)
    input_data.tofile(circle_model + ".input" + str(i))
66 str(num_inputs), circle_model +
".input", circle_model +
".output"
# Compare the outputs stored on disk against the TFLite interpreter's outputs.
# For each entry idx of interpreter.get_output_details():
#   - "<circle_model>.output<idx>" is read with np.fromfile using that
#     output's dtype,
#   - "<circle_model>.output<idx>.shape" is read as comma-separated integers
#     and used to reshape the data,
#   - the interpreter tensor is fetched by output_details["index"], or by
#     full_signatures_outputs_remap[idx] when that remap is not None.
# uint8, int64, int32 and int16 outputs are compared with
# np.allclose(rtol=0, atol=0), i.e. they must match exactly; the float32
# comparison passes rtol=1.e-5. A mismatch raises SystemExit with
# "Execution result of <tflite_model> does not match with <circle_model>".
# NOTE(review): the embedded line numbers skip 83, 90, 105 and 107, so some
# statements of this loop are not present in this text (presumably a `try:`,
# the rest of the float32 np.allclose call, an `else:` and an `except:`) —
# confirm against the original script before editing.
# NOTE(review): shape_file is opened with no visible close() or `with` block —
# confirm, and consider a context manager.
71inpt_output_details = interpreter.get_output_details()
72for idx
in range(len(inpt_output_details)):
73 output_details = inpt_output_details[idx]
74 output_data = np.fromfile(circle_model +
".output" +
str(idx),
75 output_details[
"dtype"])
76 shape_file = open(circle_model +
".output" +
str(idx) +
".shape",
'r')
77 output_shape = [int(i)
for i
in shape_file.read().split(
',')]
78 luci_output_data = np.reshape(output_data, output_shape)
79 output_tensor = output_details[
"index"]
80 if full_signatures_outputs_remap !=
None:
81 output_tensor = full_signatures_outputs_remap[idx]
82 intp_output_data = interpreter.get_tensor(output_tensor)
84 if output_details[
"dtype"] == np.uint8:
85 if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) ==
False:
86 raise SystemExit(
"Execution result of " + tflite_model +
87 " does not match with " + circle_model)
88 elif output_details[
"dtype"] == np.float32:
89 if np.allclose(luci_output_data, intp_output_data, rtol=1.e-5,
91 raise SystemExit(
"Execution result of " + tflite_model +
92 " does not match with " + circle_model)
93 elif output_details[
"dtype"] == np.int64:
94 if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) ==
False:
95 raise SystemExit(
"Execution result of " + tflite_model +
96 " does not match with " + circle_model)
97 elif output_details[
"dtype"] == np.int32:
98 if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) ==
False:
99 raise SystemExit(
"Execution result of " + tflite_model +
100 " does not match with " + circle_model)
101 elif output_details[
"dtype"] == np.int16:
102 if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) ==
False:
103 raise SystemExit(
"Execution result of " + tflite_model +
104 " does not match with " + circle_model)
106 raise SystemExit(
"Unsupported data type: ", output_details[
"dtype"])
108 print(traceback.format_exc())