ONE - On-device Neural Engine
Loading...
Searching...
No Matches
basesession.py
Go to the documentation of this file.
1from typing import List
2import numpy as np
3
4from ..native.libnnfw_api_pybind import infer, tensorinfo
5from ..native.libnnfw_api_pybind.exception import OnertError
6
7
def num_elems(tensor_info):
    """Return the product of the first ``rank`` entries of ``tensor_info.dims``.

    Args:
        tensor_info: Object exposing an integer ``rank`` and an indexable
            ``dims`` (e.g. nnfw_tensorinfo).

    Returns:
        The element count described by the leading ``rank`` dimensions;
        1 when ``rank`` is 0.
    """
    product = 1
    # Index explicitly: only the first `rank` entries of dims are consulted.
    extents = (tensor_info.dims[axis] for axis in range(tensor_info.rank))
    for extent in extents:
        product *= extent
    return product
14
15
17 """
18 Base class providing common functionality for inference and training sessions.
19 """
20 def __init__(self, backend_session=None):
21 """
22 Initialize the BaseSession with a backend session.
23 Args:
24 backend_session: A backend-specific session object (e.g., nnfw_session).
25 """
26 self.session = backend_session
27 self.inputs = []
28 self.outputs = []
29
30 def __getattr__(self, name):
31 """
32 Delegate attribute access to the bound NNFW_SESSION instance.
33 Args:
34 name (str): The name of the attribute or method to access.
35 Returns:
36 The attribute or method from the bound NNFW_SESSION instance.
37 """
38 if name in self.__dict__:
39 # First, try to get the attribute from the instance's own dictionary
40 return self.__dict__[name]
41 elif hasattr(self.session, name):
42 # If not found, delegate to the session object
43 return getattr(self.session, name)
44 else:
45 raise AttributeError(
46 f"'{type(self).__name__}' object has no attribute '{name}'")
47
48 def _recreate_session(self, backend_session):
49 """
50 Protected method to recreate the session.
51 Subclasses can override this method to provide custom session recreation logic.
52 """
53 if self.session is not None:
54 del self.session # Clean up the existing session
55 self.session = backend_session
56
57 def get_inputs_tensorinfo(self) -> List[tensorinfo]:
58 """
59 Retrieve tensorinfo for all input tensors.
60
61 Raises:
62 OnertError: If the underlying C-API call fails.
63
64 Returns:
65 list[tensorinfo]: A list of tensorinfo objects for each input.
66 """
67 num_inputs: int = self.session.input_size()
68 infos: List[tensorinfo] = []
69 for i in range(num_inputs):
70 try:
71 infos.append(self.session.input_tensorinfo(i))
72 except ValueError:
73 raise
74 except Exception as e:
75 raise OnertError(f"Failed to get input tensorinfo #{i}: {e}") from e
76 return infos
77
78 def get_outputs_tensorinfo(self) -> List[tensorinfo]:
79 """
80 Retrieve tensorinfo for all output tensors.
81
82 Raises:
83 OnertError: If the underlying C-API call fails.
84
85 Returns:
86 list[tensorinfo]: A list of tensorinfo objects for each output.
87 """
88 num_outputs: int = self.session.output_size()
89 infos: List[tensorinfo] = []
90 for i in range(num_outputs):
91 try:
92 infos.append(self.session.output_tensorinfo(i))
93 except ValueError:
94 raise
95 except Exception as e:
96 raise OnertError(f"Failed to get output tensorinfo #{i}: {e}") from e
97 return infos
98
99 def set_inputs(self, size, inputs_array=[]):
100 """
101 Set the input tensors for the session.
102
103 Args:
104 size (int): Number of input tensors.
105 inputs_array (list): List of numpy arrays for the input data.
106
107 Raises:
108 ValueError: If session uninitialized.
109 OnertError: If any C-API call fails.
110 """
111 if self.session is None:
112 raise ValueError(
113 "Session is not initialized with a model. Please compile with a model before setting inputs."
114 )
115
116 self.inputs = []
117 for i in range(size):
118 try:
119 input_tensorinfo = self.session.input_tensorinfo(i)
120 except ValueError:
121 raise
122 except Exception as e:
123 raise OnertError(f"Failed to get input tensorinfo #{i}: {e}") from e
124
125 if len(inputs_array) > i:
126 input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype)
127 else:
128 print(
129 f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0."
130 )
131 input_array = np.zeros((num_elems(input_tensorinfo)),
132 dtype=input_tensorinfo.dtype)
133
134 # Check if the shape of input_array matches the dims of input_tensorinfo
135 if input_array.shape != tuple(input_tensorinfo.dims):
136 # If not, set the input tensor info to match the input_array shape
137 try:
138 input_tensorinfo.rank = len(input_array)
139 input_tensorinfo.dims = list(input_array.shape)
140 self.session.set_input_tensorinfo(i, input_tensorinfo)
141 except Exception as e:
142 raise OnertError(f"Failed to set input tensor info #{i}: {e}") from e
143
144 try:
145 self.session.set_input(i, input_array)
146 except ValueError:
147 raise
148 except Exception as e:
149 raise OnertError(f"Failed to set input #{i}: {e}") from e
150
151 self.inputs.append(input_array)
152
153 def _set_outputs(self, size):
154 """
155 Set the output tensors for the session.
156
157 Args:
158 size (int): Number of output tensors.
159
160 Raises:
161 ValueError: If session uninitialized.
162 OnertError: If any C-API call fails.
163 """
164 if self.session is None:
165 raise ValueError(
166 "Session is not initialized with a model. Please compile a model before setting outputs."
167 )
168
169 self.outputs = []
170 for i in range(size):
171 try:
172 output_array = self.session.get_output(i)
173 except ValueError:
174 raise
175 except Exception as e:
176 raise OnertError(f"Failed to get output #{i}: {e}") from e
177
178 self.outputs.append(output_array)
179
180
def tensorinfo():
    """
    Build and return a new tensorinfo object via ``infer.nnfw_tensorinfo()``.

    Note: this module-level function reuses the name ``tensorinfo`` that the
    module also imports from the native binding, so after this definition the
    module-level name refers to this function.

    Returns:
        The object returned by ``infer.nnfw_tensorinfo()``.
    Raises:
        OnertError: If the C-API call fails. An OnertError raised by the call
            itself is propagated unchanged; any other exception is wrapped.
    """
    try:
        fresh_info = infer.nnfw_tensorinfo()
    except OnertError:
        # Already the package's error type; do not wrap it again
        raise
    except Exception as err:
        raise OnertError(f"Failed to create tensorinfo: {err}") from err
    return fresh_info
set_inputs(self, size, inputs_array=[])
_recreate_session(self, backend_session)
__init__(self, backend_session=None)
List[tensorinfo] get_inputs_tensorinfo(self)
List[tensorinfo] get_outputs_tensorinfo(self)