4from pathlib
import Path
12 p = re.compile(
'eval\\((.*)\\)')
14 return result.group(1)
17parser = argparse.ArgumentParser()
18parser.add_argument(
'--lib_path', type=str, required=
True)
19parser.add_argument(
'--test_list', type=str, required=
True)
20parser.add_argument(
'--artifact_dir', type=str, required=
True)
21args = parser.parse_args()
23with open(args.test_list)
as f:
24 contents = [line.rstrip()
for line
in f]
26eval_lines = [line
for line
in contents
if line.startswith(
'eval(')]
28test_models = [Path(args.artifact_dir) / f
'{arg}.circle' for arg
in test_args]
30 Path(args.artifact_dir) / f
'{arg}.opt/metadata/tc/input.h5' for arg
in test_args
32expected_output_data = [
33 Path(args.artifact_dir) / f
'{arg}.opt/metadata/tc/expected.h5' for arg
in test_args
42 typedef struct InterpreterWrapper InterpreterWrapper;
44 const char *get_last_error(void);
45 void clear_last_error(void);
46 InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size);
47 void Interpreter_delete(InterpreterWrapper *intp);
48 void Interpreter_interpret(InterpreterWrapper *intp);
49 void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size);
50 void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size);
52C = ffi.dlopen(args.lib_path)
56 error_message = ffi.string(C.get_last_error()).decode(
'utf-8')
59 raise RuntimeError(f
'C++ Exception: {error_message}')
64 Decorator to wrap functions with error checking.
66 def wrapper(*args, **kwargs):
67 result = func(*args, **kwargs)
77Interpreter_writeInputTensor =
error_checked(C.Interpreter_writeInputTensor)
78Interpreter_readOutputTensor =
error_checked(C.Interpreter_readOutputTensor)
80for idx, model_path
in enumerate(test_models):
81 with open(model_path,
"rb")
as f:
82 model_data = ffi.from_buffer(bytearray(f.read()))
88 h5 = h5py.File(input_data[idx])
89 input_values = h5.get(
'value')
90 input_num = len(input_values)
91 for input_idx
in range(input_num):
92 arr = np.array(input_values.get(
str(input_idx)))
93 c_arr = ffi.from_buffer(arr)
98 h5 = h5py.File(expected_output_data[idx])
99 output_values = h5.get(
'value')
100 output_num = len(output_values)
101 for output_idx
in range(output_num):
102 arr = np.array(output_values.get(
str(output_idx)))
103 result = np.empty(arr.shape, dtype=arr.dtype)
106 if not np.allclose(result, arr):
107 raise RuntimeError(
"Wrong outputs")
110 except RuntimeError
as e:
Interpreter_readOutputTensor
extract_test_args(s)
Managing paths for the artifacts required by the test.
Interpreter_writeInputTensor