ONE - On-device Neural Engine
part_eval_one.py
#!/usr/bin/env python3
import numpy as np
import tensorflow as tf
import subprocess
import argparse
import traceback
import json

#
# This script compares the execution result of TFLite interpreter and
# partitioned model(s) from a circle model
#
# Basic usage for example:
#   part_eval_one.py \
#     --driver build/compiler/circle-part-driver/circle-part-driver \
#     --name test_file
#
parser = argparse.ArgumentParser()
parser.add_argument('--driver', type=str, required=True)
parser.add_argument('--name', type=str, required=True)
args = parser.parse_args()

driver = args.driver
tflite_model = args.name + ".tflite"
circle_model = args.name + ".circle"
partition_conn_ini = args.name + ".conn.ini"
partition_conn_json = args.name + ".conn.json"
expected_count = args.name + ".excnt"
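# Artifacts derived from --name: the .tflite/.circle model pair, the partition
# connection files (.conn.ini is passed to circle-part-driver, .conn.json lists
# the partitioned "parts"), and .excnt with the expected partition count.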

# Check expected count of models from partitioning
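# The .excnt file is expected to hold a single integer; a value of 0 (or a
# missing/unreadable file) skips this check.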
try:
    with open(expected_count, "r") as expected_count_file:
        expected_count_line = expected_count_file.readline()

    expected_count_line = int(expected_count_line)
    if expected_count_line:
        with open(partition_conn_json) as json_file:
            json_data = json.load(json_file)
            parts_value = json_data["parts"]
            if len(parts_value) != expected_count_line:
                print("Partitioned model count differs from expected:",
                      expected_count_line)
                quit(255)

            print("Partitioned model count expected: ", expected_count_line)
    else:
        print("Skip expected partitioned model count check: 0")

except:
    print("Skip expected partitioned model count check: error")

# Build TFLite interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()

# Read SignatureDef and get output tensor id orders for remapping
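# Presumably needed because the 'serving_default' signature may order outputs
# differently than get_output_details(); the remap list records the signature's
# output tensor ids so each output can be fetched by that id below.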
full_signatures = interpreter._get_full_signature_list()
full_signatures_outputs_remap = None
if full_signatures is not None:
    signature_serving_default = full_signatures.get('serving_default', None)
    if signature_serving_default is not None:
        signature_outputs = signature_serving_default['outputs']

        full_signatures_outputs_remap = []
        for index, (key, value) in enumerate(signature_outputs.items()):
            full_signatures_outputs_remap.append(value)

# Generate random input data.
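# Each input tensor gets a random value appropriate to its dtype; the same data
# is also written to <circle_model>.input<i> so circle-part-driver runs on
# identical inputs.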
num_inputs = len(interpreter.get_input_details())
for i in range(num_inputs):
    input_details = interpreter.get_input_details()[i]
    input_details_dtype = input_details["dtype"]
    input_details_shape = input_details["shape"]
    if input_details_dtype == np.float32:
        input_data = np.array(
            np.random.random_sample(input_details_shape), input_details_dtype)
    elif input_details_dtype == np.int16:
        input_data = np.array(
            np.random.randint(0, 100, size=input_details_shape), input_details_dtype)
    elif input_details_dtype == np.uint8:
        input_data = np.array(
            np.random.randint(0, 256, size=input_details_shape), input_details_dtype)
    elif input_details_dtype == np.bool_:
        input_data = np.array(
            np.random.choice(a=[True, False], size=input_details_shape),
            input_details_dtype)
    else:
        raise SystemExit("Unsupported input dtype")

    interpreter.set_tensor(input_details["index"], input_data)
    input_data.tofile(circle_model + ".input" + str(i))

# Do inference
interpreter.invoke()

# Execute circle-part-driver.
partition_command = [
    driver, partition_conn_ini,
    str(num_inputs), circle_model + ".input", circle_model + ".output"
]
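# With --name test_file and a single input, the expanded command would look like
# (hypothetical driver path):
#   circle-part-driver test_file.conn.ini 1 test_file.circle.input test_file.circle.output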
print("Run: ")
for arg in partition_command:
    print(" ", arg, "\\")
print("", flush=True)

subprocess.run(partition_command, check=True)

# Compare the results.
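# circle-part-driver is expected to have written <circle_model>.output<idx> as raw
# data plus a .shape file with comma-separated dimensions; each output is reshaped
# and checked against the TFLite result (exact match for integer types, 1e-5
# tolerance for float32).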
inpt_output_details = interpreter.get_output_details()
for idx in range(len(inpt_output_details)):
    output_details = inpt_output_details[idx]
    output_dtype = output_details["dtype"]
    output_data = np.fromfile(circle_model + ".output" + str(idx), output_dtype)
    shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r')
    output_shape = [int(i) for i in shape_file.read().split(',')]
    luci_output_data = np.reshape(output_data, output_shape)
    output_tensor = output_details["index"]
    if full_signatures_outputs_remap is not None:
        output_tensor = full_signatures_outputs_remap[idx]
    intp_output_data = interpreter.get_tensor(output_tensor)
    try:
        if output_dtype == np.uint8:
            if not np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0):
                raise SystemExit("Execution result of " + tflite_model +
                                 " does not match with " + circle_model)
        elif output_dtype == np.float32:
            if not np.allclose(
                    luci_output_data, intp_output_data, rtol=1.e-5, atol=1.e-5):
                raise SystemExit("Execution result of " + tflite_model +
                                 " does not match with " + circle_model)
        elif output_dtype == np.int64:
            if not np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0):
                raise SystemExit("Execution result of " + tflite_model +
                                 " does not match with " + circle_model)
        elif output_dtype == np.int32:
            if not np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0):
                raise SystemExit("Execution result of " + tflite_model +
                                 " does not match with " + circle_model)
        elif output_dtype == np.int16:
            if not np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0):
                raise SystemExit("Execution result of " + tflite_model +
                                 " does not match with " + circle_model)
        else:
            raise SystemExit("Unsupported data type: " + str(output_dtype))
    except:
        print(traceback.format_exc())
        quit(255)

quit(0)