47 circle_model_ref = os.path.join(ref_artifacts, test_name +
".circle")
48 circle_model = os.path.join(target_artifacts, test_name +
".circle")
51 rtolint = int(rtolf32)
52 atolint = int(atolf32)
57 input_file_path = circle_model_ref +
".input" + str(check_input)
58 if not os.path.isfile(input_file_path):
59 num_inputs = check_input
61 check_input = check_input + 1
63 assert num_inputs != 0,
"input file not exist for " + circle_model_ref
68 output_file_path = circle_model_ref +
".output" + str(check_output)
69 if not os.path.isfile(output_file_path):
70 num_outputs = check_output
72 check_output = check_output + 1
74 assert num_outputs != 0,
"output file not exist for " + circle_model_ref
78 eval_driver, circle_model_ref,
79 str(num_inputs), circle_model_ref +
".input", circle_model +
".output"
84 for idx
in range(num_outputs):
85 output_dtype =
dtype_from_file(circle_model_ref +
".output" + str(idx) +
".dtype")
86 shape_file = open(circle_model_ref +
".output" + str(idx) +
".shape",
'r')
87 output_shape = [int(i)
for i
in shape_file.read().split(
',')]
89 output_data_ref = np.fromfile(circle_model_ref +
".output" + str(idx),
91 luci_output_data_ref = np.reshape(output_data_ref, output_shape)
93 output_data = np.fromfile(circle_model +
".output" + str(idx), output_dtype)
94 luci_output_data = np.reshape(output_data, output_shape)
96 err_msg =
"Execution result of " + circle_model_ref +
" does not match with " + circle_model
97 if output_dtype == np.uint8:
98 assert np.allclose(luci_output_data,
101 atol=atolint), err_msg
102 elif output_dtype == np.float32:
103 assert np.allclose(luci_output_data,
104 luci_output_data_ref,
106 atol=atolf32), err_msg
107 elif output_dtype == np.int64:
108 assert np.allclose(luci_output_data,
109 luci_output_data_ref,
111 atol=atolint), err_msg
112 elif output_dtype == np.int32:
113 assert np.allclose(luci_output_data,
114 luci_output_data_ref,
116 atol=atolint), err_msg
117 elif output_dtype == np.int16:
118 assert np.allclose(luci_output_data,
119 luci_output_data_ref,
121 atol=atolint), err_msg
122 elif output_dtype == np.bool_:
123 assert np.allclose(luci_output_data, luci_output_data_ref, rtol=0,
126 assert False,
"Unsupported data type: " + output_dtype
138 target_artifacts_path: str, eval_driver_path: str,
139 rtolf32: str, atolf32: str):
141 eval_driver_path, float(rtolf32), float(atolf32))