ONE - On-device Neural Engine
Loading...
Searching...
No Matches
eval_result_verifier.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2import numpy as np
3import tensorflow as tf
4import subprocess
5import argparse
6import traceback
7
#
# This script was copied from luci-value-test and modified to take the tflite
# and circle model paths as input arguments.
#
# Command line: the driver executable that runs the Circle model, plus the
# paths of the reference TFLite model and the Circle model to verify.
# All three flags are mandatory strings.
parser = argparse.ArgumentParser()
for flag in ('--driver', '--tflite', '--circle'):
    parser.add_argument(flag, type=str, required=True)
args = parser.parse_args()

driver, tflite_model, circle_model = args.driver, args.tflite, args.circle
20
# Build the reference TFLite interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()

# Read the 'serving_default' SignatureDef, if present, and record its output
# tensor ids in signature order. When set, the comparison loop uses the
# idx-th entry of this list instead of the idx-th output detail's index.
# NOTE(review): _get_full_signature_list() is a private TFLite API and may
# change between TensorFlow releases.
full_signatures = interpreter._get_full_signature_list()
full_signatures_outputs_remap = None
if full_signatures is not None:
    signature_serving_default = full_signatures.get('serving_default', None)
    if signature_serving_default is not None:
        signature_outputs = signature_serving_default['outputs']
        full_signatures_outputs_remap = list(signature_outputs.values())
36
# Generate random input data for every model input. Each array is fed to the
# TFLite interpreter and also saved as "<circle_model>.input<i>" so the
# driver can run the Circle model on identical data.
# Maps each supported input dtype to a function producing random values of
# the given shape; the result is cast to the input's dtype below.
input_generators = {
    np.dtype(np.float32): lambda shape: np.random.random_sample(shape),
    np.dtype(np.uint8): lambda shape: np.random.randint(0, 256, size=shape),
    np.dtype(np.int16): lambda shape: np.random.randint(0, 100, size=shape),
    np.dtype(np.bool_): lambda shape: np.random.choice(a=[True, False], size=shape),
}

# Query the input details once instead of on every loop iteration.
all_input_details = interpreter.get_input_details()
num_inputs = len(all_input_details)
for i, input_details in enumerate(all_input_details):
    generate = input_generators.get(np.dtype(input_details["dtype"]))
    if generate is None:
        raise SystemExit("Unsupported input dtype")
    input_data = np.array(generate(input_details["shape"]), input_details["dtype"])

    interpreter.set_tensor(input_details["index"], input_data)
    input_data.tofile(circle_model + ".input" + str(i))
59
# Run the reference TFLite interpreter on the inputs set above.
interpreter.invoke()

# Run the Circle model with the luci interpreter driver. The driver receives
# the model path, the number of inputs, and the input/output file prefixes
# (presumably it reads "<prefix><i>" inputs and writes "<prefix><idx>" outputs
# plus ".shape" files, matching the names used elsewhere in this script).
# check=True makes a non-zero driver exit raise CalledProcessError.
driver_command = [
    driver,
    circle_model,
    str(num_inputs),
    circle_model + ".input",
    circle_model + ".output",
]
subprocess.run(driver_command, check=True)
69
# Compare the results: each output written by the driver must match the
# corresponding TFLite output within the tolerance for its dtype.
# Maps each supported output dtype to the (rtol, atol) passed to np.allclose;
# integer outputs must match exactly.
output_tolerances = {
    np.dtype(np.uint8): (0, 0),
    np.dtype(np.float32): (1.e-5, 1.e-5),
    np.dtype(np.int64): (0, 0),
    np.dtype(np.int32): (0, 0),
    np.dtype(np.int16): (0, 0),
}

interp_output_details = interpreter.get_output_details()
for idx, output_details in enumerate(interp_output_details):
    output_path = circle_model + ".output" + str(idx)
    output_data = np.fromfile(output_path, output_details["dtype"])
    # The shape file holds comma-separated dimensions; read it inside `with`
    # so the handle is closed promptly.
    with open(output_path + ".shape", 'r') as shape_file:
        shape_text = shape_file.read()
    # Empty fields are skipped so an empty shape file (presumably a rank-0
    # output; confirm against the driver) becomes a scalar shape instead of
    # failing in int('').
    output_shape = [int(dim) for dim in shape_text.split(',') if dim.strip()]
    luci_output_data = np.reshape(output_data, output_shape)

    output_tensor = output_details["index"]
    if full_signatures_outputs_remap is not None:
        output_tensor = full_signatures_outputs_remap[idx]
    intp_output_data = interpreter.get_tensor(output_tensor)

    try:
        output_dtype = output_details["dtype"]
        tolerance = output_tolerances.get(np.dtype(output_dtype))
        if tolerance is None:
            raise SystemExit(f"Unsupported data type: {output_dtype}")
        rtol, atol = tolerance
        if not np.allclose(luci_output_data, intp_output_data, rtol=rtol, atol=atol):
            raise SystemExit("Execution result of " + tflite_model +
                             " does not match with " + circle_model)
    except (Exception, SystemExit):
        # Any comparison failure is printed with its traceback and mapped to
        # exit status 255. This covers a mismatch, an unsupported dtype, or an
        # error raised by np.allclose itself. KeyboardInterrupt is not caught.
        print(traceback.format_exc())
        raise SystemExit(255)

# Raise SystemExit directly rather than calling quit(), which is only
# provided by the site module and is intended for the interactive shell.
raise SystemExit(0)