ONE - On-device Neural Engine
Loading...
Searching...
No Matches
infer.py
Go to the documentation of this file.
1import argparse
2import h5py
3import numpy as np
4from pathlib import Path
5import re
6import sys
7
8
9
def extract_test_args(s):
    """Return the name captured from an `eval(...)` directive line.

    E.g. 'eval(Add_000)' -> 'Add_000'.  The capture is greedy, so the
    text between the first '(' and the last ')' is returned.

    Raises AttributeError if *s* contains no `eval(...)` pattern
    (re.search returns None).
    """
    # Raw string avoids doubled backslashes; pattern is `eval\((.*)\)`.
    p = re.compile(r'eval\((.*)\)')
    result = p.search(s)
    return result.group(1)
16
# Command-line interface: every option is a mandatory path string.
parser = argparse.ArgumentParser()
for option in ('--lib_path', '--test_list', '--artifact_dir'):
    parser.add_argument(option, type=str, required=True)
args = parser.parse_args()
22
# Read the test list; entries look like `eval(NAME)` -- anything else
# (blank lines, comments) is dropped by the startswith filter below.
with open(args.test_list) as test_list_file:
    contents = [raw.rstrip() for raw in test_list_file]
# remove newline and comments.
eval_lines = [entry for entry in contents if entry.startswith('eval(')]
test_args = [extract_test_args(entry) for entry in eval_lines]

# For each test NAME the artifact directory holds the compiled model and
# its recorded input/expected-output HDF5 test-case data.
artifact_root = Path(args.artifact_dir)
test_models = [artifact_root / f'{name}.circle' for name in test_args]
input_data = [artifact_root / f'{name}.opt/metadata/tc/input.h5' for name in test_args]
expected_output_data = [
    artifact_root / f'{name}.opt/metadata/tc/expected.h5' for name in test_args
]
35
36
37
# Imported mid-file, after CLI parsing -- NOTE(review): presumably so that
# argument errors surface even when cffi is unavailable; confirm intent.
from cffi import FFI

ffi = FFI()
# C declarations mirroring the ABI of the interpreter wrapper shared
# library loaded below; these must stay in sync with its exported symbols.
ffi.cdef("""
    typedef struct InterpreterWrapper InterpreterWrapper;

    const char *get_last_error(void);
    void clear_last_error(void);
    InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size);
    void Interpreter_delete(InterpreterWrapper *intp);
    void Interpreter_interpret(InterpreterWrapper *intp);
    void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size);
    void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size);
""")
# Load the interpreter shared library named on the command line.
C = ffi.dlopen(args.lib_path)
53
def check_for_errors():
    """Raise RuntimeError if the native library recorded a C++ exception.

    Reads the last-error slot exposed by the wrapper library via
    `get_last_error()`; an empty string means "no error".  The slot is
    cleared before raising so later calls start from a clean state.
    """
    error_message = ffi.string(C.get_last_error()).decode('utf-8')
    if error_message:
        C.clear_last_error()
        raise RuntimeError(f'C++ Exception: {error_message}')
60
61
def error_checked(func):
    """
    Decorator to wrap functions with error checking.

    After each call to *func* the native last-error slot is inspected
    via check_for_errors(), which raises RuntimeError when the C++ side
    recorded an exception.  Without this check the wrapper would be a
    no-op pass-through.
    """
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        check_for_errors()
        return result

    return wrapper
72
73
# Error-checked aliases for the raw C entry points: each call is followed
# by an inspection of the library's last-error slot (see error_checked).
Interpreter_new = error_checked(C.Interpreter_new)
Interpreter_delete = error_checked(C.Interpreter_delete)
Interpreter_interpret = error_checked(C.Interpreter_interpret)
Interpreter_writeInputTensor = error_checked(C.Interpreter_writeInputTensor)
Interpreter_readOutputTensor = error_checked(C.Interpreter_readOutputTensor)
79
# For every test model: load it, feed the recorded inputs, run inference,
# and compare the outputs against the recorded expected values.  Any
# mismatch or native error aborts the whole run with a non-zero exit code.
for idx, model_path in enumerate(test_models):
    with open(model_path, "rb") as f:
        # Keep an owning bytearray so the buffer stays alive while the
        # interpreter is constructed from it.
        model_data = ffi.from_buffer(bytearray(f.read()))

    try:
        intp = Interpreter_new(model_data, len(model_data))
        try:
            # Set inputs.  The HDF5 layout is value/<index> per tensor
            # -- assumed from the str(input_idx) lookups; TODO confirm.
            with h5py.File(input_data[idx]) as h5:
                input_values = h5.get('value')
                for input_idx in range(len(input_values)):
                    arr = np.array(input_values.get(str(input_idx)))
                    c_arr = ffi.from_buffer(arr)
                    Interpreter_writeInputTensor(intp, input_idx, c_arr, arr.nbytes)
            # Do inference
            Interpreter_interpret(intp)
            # Check outputs against the recorded expected tensors.
            with h5py.File(expected_output_data[idx]) as h5:
                output_values = h5.get('value')
                for output_idx in range(len(output_values)):
                    arr = np.array(output_values.get(str(output_idx)))
                    result = np.empty(arr.shape, dtype=arr.dtype)
                    Interpreter_readOutputTensor(intp, output_idx,
                                                 ffi.from_buffer(result), arr.nbytes)
                    if not np.allclose(result, arr):
                        raise RuntimeError("Wrong outputs")
        finally:
            # Release the native interpreter even when a check fails.
            Interpreter_delete(intp)
    except RuntimeError as e:
        print(e)
        sys.exit(-1)
check_for_errors()
Definition infer.py:55
Interpreter_interpret
Definition infer.py:76
Interpreter_delete
Definition infer.py:75
error_checked(func)
Definition infer.py:62
Interpreter_readOutputTensor
Definition infer.py:78
extract_test_args(s)
Managing paths for the artifacts required by the test.
Definition infer.py:11
Interpreter_new
Definition infer.py:74
Interpreter_writeInputTensor
Definition infer.py:77
str
Definition infer.py:18