ONE - On-device Neural Engine
Loading...
Searching...
No Matches
session.py
Go to the documentation of this file.
1from typing import List, Union, Tuple, Dict
2import numpy as np
3import time
4import warnings
5from contextlib import contextmanager
6
7from ..native.libnnfw_api_pybind import infer, prepare_config, tensorinfo
8from ..native.libnnfw_api_pybind.exception import OnertError
9from ..common.basesession import BaseSession
10
11
13 """
14 Class for inference using nnfw_session.
15 """
def __init__(self, path: str, backends: str = "cpu") -> None:
    """
    Create an inference session backed by a native nnfw_session.

    Args:
        path (str): Location of the model file or nnpackage directory.
        backends (str): Backend string forwarded to nnfw_session ("cpu" unless overridden).
    """
    native_session = infer.nnfw_session(path, backends)
    super().__init__(native_session)
    # Preparation is deferred: infer() checks this flag and prepares on first use.
    self._prepared: bool = False
26
def infer(
    self,
    inputs_array: List[np.ndarray],
    *,
    measure: bool = False
) -> Union[List[np.ndarray], Tuple[List[np.ndarray], Dict[str, float]]]:
    """
    Run a complete inference cycle.

    Steps:
      - On the first call (session not yet prepared), resolve dynamic (-1)
        input dimensions from the shapes of ``inputs_array``, push the
        resulting tensorinfo to the session, enable internal output
        allocation and call prepare().
      - Bind the provided numpy arrays as input buffers (set_inputs()).
      - Execute the session.
      - Collect the output tensors (_set_outputs()).

    If a static model dimension disagrees with the array supplied on the
    preparing call, a UserWarning is emitted and the supplied array's
    dimension is used for that position.

    Args:
        inputs_array (list[np.ndarray]): One numpy array per model input.
        measure (bool): If True, measure per-stage latencies in milliseconds.

    Returns:
        list[np.ndarray]: The output arrays, when ``measure`` is False.
        OR
        (outputs, metrics): When ``measure`` is True. ``metrics`` contains
            'input_time_ms', 'run_time_ms' and 'output_time_ms', plus
            'prepare_time_ms' on the call that performed preparation.

    Raises:
        ValueError: If the number of inputs does not match the model's input
            count, or if input validation (e.g. a negative resolved dimension)
            fails.
        OnertError: If preparation, input binding, execution or output
            collection fails for any other reason.
    """
    metrics: Dict[str, float] = {}

    # Verify that the number of provided inputs matches the session's expected input count.
    expected_input_size: int = self.session.input_size()
    if len(inputs_array) != expected_input_size:
        raise ValueError(
            f"Expected {expected_input_size} input(s), but received {len(inputs_array)}."
        )

    # Prepare the session exactly once, on the first call.
    if not self._prepared:
        try:
            with self._time_block(metrics, 'prepare_time_ms', measure):
                # Resolve -1 dims from the real input shapes and validate static dims.
                original_infos = self.get_inputs_tensorinfo()
                fixed_infos = []
                for idx, info in enumerate(original_infos):
                    input_shape = inputs_array[idx].shape
                    new_dims = []
                    static_dim_changed = False
                    # Only the first `info.rank` entries are meaningful.
                    for j, d in enumerate(info.dims[:info.rank]):
                        if d == -1:
                            # Dynamic dim: take it from the incoming array.
                            new_dims.append(input_shape[j])
                        elif d == input_shape[j]:
                            # Static dim matches the provided array.
                            new_dims.append(d)
                        else:
                            # Static dim differs. Still append a value for this
                            # position so the rank is preserved; the caller's
                            # dimension is used and a warning is emitted below.
                            static_dim_changed = True
                            new_dims.append(input_shape[j])

                    if static_dim_changed:
                        warnings.warn(
                            f"infer() called with input {idx}'s shape={input_shape}, "
                            f"which differs from model's expected shape={tuple(info.dims)}. "
                            "Ensure this is intended.", UserWarning)

                    info.dims = new_dims
                    fixed_infos.append(info)

                # Push the resolved tensorinfo so prepare() can plan with concrete shapes.
                self._update_inputs_tensorinfo(fixed_infos)

                self.session.set_prepare_config(
                    prepare_config.ENABLE_INTERNAL_OUTPUT_ALLOC)
                self.session.prepare()
                self._prepared = True
        except ValueError:
            raise
        except Exception as e:
            raise OnertError(f"Session preparation failed: {e}") from e

    # Bind the caller's arrays as input buffers.
    try:
        with self._time_block(metrics, 'input_time_ms', measure):
            self.set_inputs(expected_input_size, inputs_array)
    except ValueError:
        raise
    except Exception as e:
        raise OnertError(f"Failed to bind inputs: {e}") from e

    # Execute the inference.
    try:
        with self._time_block(metrics, 'run_time_ms', measure):
            self.session.run()
    except ValueError:
        raise
    except Exception as e:
        raise OnertError(f"Inference execution failed: {e}") from e

    # Collect the outputs into self.outputs.
    try:
        with self._time_block(metrics, 'output_time_ms', measure):
            self._set_outputs(self.session.output_size())
    except ValueError:
        raise
    except Exception as e:
        raise OnertError(f"Failed to bind outputs: {e}") from e

    return (self.outputs, metrics) if measure else self.outputs
133
def _update_inputs_tensorinfo(self, new_infos: List[tensorinfo]) -> None:
    """
    Push a full set of input tensorinfo objects to the native session.

    Entries are applied in index order. A failure on a later index leaves the
    earlier indices already updated.

    Args:
        new_infos (list[tensorinfo]): One tensorinfo per model input.

    Raises:
        ValueError: If len(new_infos) differs from the session's input count,
            if any of the first `rank` dims of an entry is negative, or if
            set_input_tensorinfo itself raises ValueError.
        OnertError: If set_input_tensorinfo fails with any other exception.
    """
    expected = self.session.input_size()
    provided = len(new_infos)
    if provided != expected:
        raise ValueError(
            f"Expected {expected} input tensorinfo(s), but got {provided}.")

    for index, tinfo in enumerate(new_infos):
        active_dims = tinfo.dims[:tinfo.rank]
        # Reject unresolved or invalid dims before handing them to the native layer.
        if not all(dim >= 0 for dim in active_dims):
            raise ValueError(
                f"Input tensorinfo at index {index} contains negative dimension(s): "
                f"{active_dims}")
        try:
            self.session.set_input_tensorinfo(index, tinfo)
        except ValueError:
            # Validation errors from the binding propagate unchanged.
            raise
        except Exception as err:
            raise OnertError(f"Failed to update input tensorinfo: {err}") from err
165
@contextmanager
def _time_block(self, metrics: Dict[str, float], key: str, measure: bool):
    """
    Optionally time the body of a ``with`` statement.

    When ``measure`` is False this is a no-op. When True, the wall-clock
    duration of the body is measured with time.perf_counter() and stored in
    milliseconds as ``metrics[key]``. This happens only if the body finishes
    without raising; on an exception nothing is recorded.

    Args:
        metrics (dict[str, float]): Mapping that receives the measurement.
        key (str): Key under which the elapsed milliseconds are stored.
        measure (bool): Whether to measure at all.
    """
    if not measure:
        yield
        return
    started = time.perf_counter()
    yield
    elapsed = time.perf_counter() - started
    metrics[key] = elapsed * 1000
void run(std::ofstream &os, const circle::Model *model)
set_inputs(self, size, inputs_array=[])
List[tensorinfo] get_inputs_tensorinfo(self)
None _update_inputs_tensorinfo(self, List[tensorinfo] new_infos)
Definition session.py:134
_time_block(self, Dict[str, float] metrics, str key, bool measure)
Definition session.py:167
None __init__(self, str path, str backends="cpu")
Definition session.py:16
Definition infer.py:1