ONE - On-device Neural Engine
Loading...
Searching...
No Matches
test_luci_eval_ref.py
Go to the documentation of this file.
1import numpy as np
2import tensorflow as tf
3import subprocess
4import os
5
6#
7# This script compares the execution result of luci-interpreter with that from ref_model path
8#
9# Basic usage:
10# luci_eval_verifier_ref.py --driver build/compiler/luci-eval-driver/luci_eval_driver
11# --ref_model ref_model_path --model this_model_path
# Assumption:
#   the following files exist, each serving the purpose noted:
14# - ref_model_path.circle; circle model
15# - ref_model_path.circle.inputN; N'th input numpy data
16# - ref_model_path.circle.inputN.dtype; N'th input data type in text
17# - ref_model_path.circle.inputN.shape; N'th input data shape in CSV
18# - ref_model_path.circle.outputN; N'th output numpy data
19# - ref_model_path.circle.outputN.dtype; N'th output data type in text
20# - ref_model_path.circle.outputN.shape; N'th output data shape in CSV
21
22
def dtype_from_file(file_path):
    """Read a numpy scalar type from a text file that contains its name.

    Args:
        file_path: path to a text file holding a single dtype name,
            e.g. "float32". Surrounding whitespace (such as a trailing
            newline) is ignored.

    Returns:
        The matching numpy scalar type: np.float32, np.uint8, np.int16,
        np.int32, np.int64 or np.bool_.

    Raises:
        AssertionError: if the file names a dtype that is not supported.
    """
    # Accepted dtype names and the numpy types they map to.
    supported_dtypes = {
        "float32": np.float32,
        "uint8": np.uint8,
        "int16": np.int16,
        "int32": np.int32,
        "int64": np.int64,
        "bool": np.bool_,
    }
    with open(file_path, 'r') as dtype_file:
        # Strip so that a trailing newline does not make a valid name unsupported.
        dtype_str = dtype_file.read().strip()
    if dtype_str not in supported_dtypes:
        # Raise explicitly (rather than `assert False`) so the check survives `python -O`.
        raise AssertionError("Unsupported dtype from file: " + dtype_str)
    return supported_dtypes[dtype_str]
39
40
def _count_indexed_files(prefix):
    """Return how many consecutive files prefix+"0", prefix+"1", ... exist."""
    count = 0
    while os.path.isfile(prefix + str(count)):
        count += 1
    return count


def _read_shape(shape_path):
    """Parse a comma-separated shape file into a list of ints.

    An empty (or whitespace-only) file yields [], i.e. a scalar shape.
    """
    with open(shape_path, 'r') as shape_file:
        shape_text = shape_file.read()
    return [int(dim) for dim in shape_text.split(',') if dim.strip()]


def luci_eval_verify_ref(test_name,
                         ref_artifacts,
                         target_artifacts,
                         eval_driver,
                         rtolf32=1e-5,
                         atolf32=1e-5):
    """Run the luci eval driver on reference inputs and compare with reference outputs.

    Files read from `ref_artifacts/<test_name>.circle.*`:
      - .inputN                      : N'th input data (counted from 0)
      - .outputN, .outputN.dtype,
        .outputN.shape               : N'th expected output, its dtype name
                                       and its comma-separated shape
    The driver writes its results to `target_artifacts/<test_name>.circle.outputN`.

    Args:
        test_name: model base name (without the ".circle" suffix).
        ref_artifacts: directory holding the reference model and its data.
        target_artifacts: directory where the driver's outputs are written.
        eval_driver: path to the luci eval driver executable.
        rtolf32: relative tolerance for float32 outputs; its int() value is
            reused for integer outputs.
        atolf32: absolute tolerance for float32 outputs; its int() value is
            reused for integer outputs.

    Raises:
        AssertionError: if input/output files are missing, an output dtype is
            unsupported, or a result does not match the reference.
        subprocess.CalledProcessError: if the driver exits with non-zero status.
    """
    circle_model_ref = os.path.join(ref_artifacts, test_name + ".circle")
    circle_model = os.path.join(target_artifacts, test_name + ".circle")

    # NOTE reuse f32 value as int value too. int() truncates, so the default
    # 1e-5 becomes 0 and integer outputs must then match exactly.
    rtolint = int(rtolf32)
    atolint = int(atolf32)

    # get num of inputs/outputs by checking existence of model.inputN / model.outputN
    num_inputs = _count_indexed_files(circle_model_ref + ".input")
    assert num_inputs != 0, "input file not exist for " + circle_model_ref

    num_outputs = _count_indexed_files(circle_model_ref + ".output")
    assert num_outputs != 0, "output file not exist for " + circle_model_ref

    # Execute luci interpreter with reference input.
    # NOTE(review): the driver runs circle_model_ref; circle_model is used only
    # as the output prefix. Confirm this is intended.
    subprocess.run([
        eval_driver, circle_model_ref,
        str(num_inputs), circle_model_ref + ".input", circle_model + ".output"
    ],
                   check=True)

    # (rtol, atol) per supported output dtype; bool outputs must match exactly.
    tolerances = {
        np.float32: (rtolf32, atolf32),
        np.uint8: (rtolint, atolint),
        np.int16: (rtolint, atolint),
        np.int32: (rtolint, atolint),
        np.int64: (rtolint, atolint),
        np.bool_: (0, 0),
    }

    err_msg = "Execution result of " + circle_model_ref + " does not match with " + circle_model

    # Compare the results.
    for idx in range(num_outputs):
        ref_output_path = circle_model_ref + ".output" + str(idx)
        output_dtype = dtype_from_file(ref_output_path + ".dtype")
        output_shape = _read_shape(ref_output_path + ".shape")

        luci_output_data_ref = np.reshape(np.fromfile(ref_output_path, output_dtype),
                                          output_shape)
        luci_output_data = np.reshape(
            np.fromfile(circle_model + ".output" + str(idx), output_dtype), output_shape)

        tol = tolerances.get(output_dtype)
        # str() is required: concatenating a numpy type object to a str raises TypeError.
        assert tol is not None, "Unsupported data type: " + str(output_dtype)
        rtol, atol = tol
        assert np.allclose(luci_output_data, luci_output_data_ref, rtol=rtol,
                           atol=atol), err_msg
127
128
# arguments must be in sync with `conftest.py`
def test_luci_eval_ref(default_ref_test_name: str, ref_artifacts_path: str,
                       target_artifacts_path: str, eval_driver_path: str):
    """Verify one reference test case using the default float32 tolerances.

    The parameter names must stay in sync with `conftest.py`, which presumably
    supplies their values.
    """
    luci_eval_verify_ref(default_ref_test_name,
                         ref_artifacts_path,
                         target_artifacts_path,
                         eval_driver_path)
134
135
# arguments must be in sync with `conftest.py`
def test_luci_eval_tol_ref(tol_ref_test_name: str, ref_artifacts_path: str,
                           target_artifacts_path: str, eval_driver_path: str,
                           rtolf32: str, atolf32: str):
    """Verify one reference test case with caller-supplied tolerances.

    The tolerances arrive as strings and are converted to float before use.
    The parameter names must stay in sync with `conftest.py`, which presumably
    supplies their values.
    """
    relative_tol = float(rtolf32)
    absolute_tol = float(atolf32)
    luci_eval_verify_ref(tol_ref_test_name, ref_artifacts_path, target_artifacts_path,
                         eval_driver_path, relative_tol, absolute_tol)
luci_eval_verify_ref(test_name, ref_artifacts, target_artifacts, eval_driver, rtolf32=1e-5, atolf32=1e-5)
test_luci_eval_tol_ref(str tol_ref_test_name, str ref_artifacts_path, str target_artifacts_path, str eval_driver_path, str rtolf32, str atolf32)