ONE - On-device Neural Engine
Loading...
Searching...
No Matches
package.common.basesession.BaseSession Class Reference

Public Member Functions

 __init__ (self, backend_session=None)
 
 __getattr__ (self, name)
 
List[tensorinfo] get_inputs_tensorinfo (self)
 
List[tensorinfo] get_outputs_tensorinfo (self)
 
 set_inputs (self, size, inputs_array=[])
 

Data Fields

 session
 
 inputs
 
 outputs
 

Protected Member Functions

 _recreate_session (self, backend_session)
 
 _set_outputs (self, size)
 

Detailed Description

Base class providing common functionality for inference and training sessions.

Definition at line 16 of file basesession.py.

Constructor & Destructor Documentation

◆ __init__()

package.common.basesession.BaseSession.__init__ (   self,
  backend_session = None 
)
Initialize the BaseSession with a backend session.
Args:
    backend_session: A backend-specific session object (e.g., nnfw_session).

Reimplemented in package.infer.session.session.

Definition at line 20 of file basesession.py.

20 def __init__(self, backend_session=None):
21 """
22 Initialize the BaseSession with a backend session.
23 Args:
24 backend_session: A backend-specific session object (e.g., nnfw_session).
25 """
26 self.session = backend_session
27 self.inputs = []
28 self.outputs = []
29

Member Function Documentation

◆ __getattr__()

package.common.basesession.BaseSession.__getattr__ (   self,
  name 
)
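Delegate attribute access to the bound NNFW_SESSION instance.
Python calls this method only when normal attribute lookup on the object fails;
the name is then looked up on the bound session object.
Args:
    name (str): The name of the attribute or method to access.
Returns:
    The attribute or method from the bound NNFW_SESSION instance.
Raises:
    AttributeError: If neither the instance nor the bound session has an attribute with this name.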

Definition at line 30 of file basesession.py.

30 def __getattr__(self, name):
31 """
32 Delegate attribute access to the bound NNFW_SESSION instance.
33 Args:
34 name (str): The name of the attribute or method to access.
35 Returns:
36 The attribute or method from the bound NNFW_SESSION instance.
37 """
38 if name in self.__dict__:
39 # First, try to get the attribute from the instance's own dictionary
40 return self.__dict__[name]
41 elif hasattr(self.session, name):
42 # If not found, delegate to the session object
43 return getattr(self.session, name)
44 else:
45 raise AttributeError(
46 f"'{type(self).__name__}' object has no attribute '{name}'")
47

References validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, and package.common.basesession.BaseSession.session.

◆ _recreate_session()

package.common.basesession.BaseSession._recreate_session (   self,
  backend_session 
)
protected
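Protected method to replace the bound backend session.
If a session is currently bound, its reference is dropped before the new backend session is assigned to `self.session`.
Subclasses can override this method to provide custom session recreation logic.
Args:
    backend_session: The new backend-specific session object to bind.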

Definition at line 48 of file basesession.py.

48 def _recreate_session(self, backend_session):
49 """
50 Protected method to recreate the session.
51 Subclasses can override this method to provide custom session recreation logic.
52 """
53 if self.session is not None:
54 del self.session # Clean up the existing session
55 self.session = backend_session
56

References validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, and package.common.basesession.BaseSession.session.

◆ _set_outputs()

package.common.basesession.BaseSession._set_outputs (   self,
  size 
)
protected
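Set the output tensors for the session.
Fetches each output array from the bound session and stores them, in order, in `self.outputs`.

Args:
    size (int): Number of output tensors.

Raises:
    ValueError: If the session has not been initialized with a model.
    OnertError: If any C-API call fails.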

Definition at line 153 of file basesession.py.

153 def _set_outputs(self, size):
154 """
155 Set the output tensors for the session.
156
157 Args:
158 size (int): Number of output tensors.
159
160 Raises:
161 ValueError: If session uninitialized.
162 OnertError: If any C-API call fails.
163 """
164 if self.session is None:
165 raise ValueError(
166 "Session is not initialized with a model. Please compile a model before setting outputs."
167 )
168
169 self.outputs = []
170 for i in range(size):
171 try:
172 output_array = self.session.get_output(i)
173 except ValueError:
174 raise
175 except Exception as e:
176 raise OnertError(f"Failed to get output #{i}: {e}") from e
177
178 self.outputs.append(output_array)
179
180

References Operation.outputs, Request.outputs, circlechef::CircleImport.outputs(), crew::Part.outputs, luci::CircleReader.outputs(), luci::PGroup.outputs, mio::circle::Reader.outputs(), moco::ModelSignature.outputs(), nnc::sir::CallFunction.outputs, validate_onnx2circle.OnnxRunner.outputs, tflinspect::Reader.outputs(), tflchef::TFliteImport.outputs(), tflread::Reader.outputs(), luci_interpreter::CircleReader.outputs(), onert_micro::core::reader::OMCircleReader.outputs(), mio::circle::Reader.outputs(), loco::Graph.outputs(), nnkit::support::onnx::Runner.outputs(), ann::Operation.outputs(), loco::Graph.outputs(), onert_micro::execute::OMRuntimeKernel.outputs, nnfw_custom_kernel_params.outputs, package.common.basesession.BaseSession.outputs, package.infer.session.session.outputs, onert::exec::IODescription.outputs, validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, and package.common.basesession.BaseSession.session.

◆ get_inputs_tensorinfo()

List[tensorinfo] package.common.basesession.BaseSession.get_inputs_tensorinfo (   self)
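Retrieve tensorinfo for all input tensors.

Raises:
    ValueError: Re-raised unchanged if the session raises it while reading an input's tensorinfo.
    OnertError: If reading any input's tensorinfo fails with any other error (the original error is chained).

Returns:
    list[tensorinfo]: A list of tensorinfo objects, one for each input, in input-index order.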

Definition at line 57 of file basesession.py.

57 def get_inputs_tensorinfo(self) -> List[tensorinfo]:
58 """
59 Retrieve tensorinfo for all input tensors.
60
61 Raises:
62 OnertError: If the underlying C-API call fails.
63
64 Returns:
65 list[tensorinfo]: A list of tensorinfo objects for each input.
66 """
67 num_inputs: int = self.session.input_size()
68 infos: List[tensorinfo] = []
69 for i in range(num_inputs):
70 try:
71 infos.append(self.session.input_tensorinfo(i))
72 except ValueError:
73 raise
74 except Exception as e:
75 raise OnertError(f"Failed to get input tensorinfo #{i}: {e}") from e
76 return infos
77

References validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, and package.common.basesession.BaseSession.session.

Referenced by package.infer.session.session.infer().

◆ get_outputs_tensorinfo()

List[tensorinfo] package.common.basesession.BaseSession.get_outputs_tensorinfo (   self)
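Retrieve tensorinfo for all output tensors.

Raises:
    ValueError: Re-raised unchanged if the session raises it while reading an output's tensorinfo.
    OnertError: If reading any output's tensorinfo fails with any other error (the original error is chained).

Returns:
    list[tensorinfo]: A list of tensorinfo objects, one for each output, in output-index order.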

Definition at line 78 of file basesession.py.

78 def get_outputs_tensorinfo(self) -> List[tensorinfo]:
79 """
80 Retrieve tensorinfo for all output tensors.
81
82 Raises:
83 OnertError: If the underlying C-API call fails.
84
85 Returns:
86 list[tensorinfo]: A list of tensorinfo objects for each output.
87 """
88 num_outputs: int = self.session.output_size()
89 infos: List[tensorinfo] = []
90 for i in range(num_outputs):
91 try:
92 infos.append(self.session.output_tensorinfo(i))
93 except ValueError:
94 raise
95 except Exception as e:
96 raise OnertError(f"Failed to get output tensorinfo #{i}: {e}") from e
97 return infos
98

References validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, and package.common.basesession.BaseSession.session.

◆ set_inputs()

package.common.basesession.BaseSession.set_inputs (   self,
  size,
  inputs_array = [] 
)
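Set the input tensors for the session.

Each provided array is converted to a numpy array of the input's dtype. If fewer
arrays than `size` are given, each missing input is replaced by a zero-filled
array and a notice is printed. If an array's shape differs from the input's
declared dims, the input tensorinfo is updated to match before the input is set.
The final arrays are stored, in order, in `self.inputs`.

Args:
    size (int): Number of input tensors.
    inputs_array (list): List of numpy arrays for the input data.

Raises:
    ValueError: If the session has not been initialized with a model.
    OnertError: If any C-API call fails.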

Definition at line 99 of file basesession.py.

99 def set_inputs(self, size, inputs_array=[]):
100 """
101 Set the input tensors for the session.
102
103 Args:
104 size (int): Number of input tensors.
105 inputs_array (list): List of numpy arrays for the input data.
106
107 Raises:
108 ValueError: If session uninitialized.
109 OnertError: If any C-API call fails.
110 """
111 if self.session is None:
112 raise ValueError(
113 "Session is not initialized with a model. Please compile with a model before setting inputs."
114 )
115
116 self.inputs = []
117 for i in range(size):
118 try:
119 input_tensorinfo = self.session.input_tensorinfo(i)
120 except ValueError:
121 raise
122 except Exception as e:
123 raise OnertError(f"Failed to get input tensorinfo #{i}: {e}") from e
124
125 if len(inputs_array) > i:
126 input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype)
127 else:
128 print(
129 f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0."
130 )
131 input_array = np.zeros((num_elems(input_tensorinfo)),
132 dtype=input_tensorinfo.dtype)
133
134 # Check if the shape of input_array matches the dims of input_tensorinfo
135 if input_array.shape != tuple(input_tensorinfo.dims):
136 # If not, set the input tensor info to match the input_array shape
137 try:
138 input_tensorinfo.rank = len(input_array)
139 input_tensorinfo.dims = list(input_array.shape)
140 self.session.set_input_tensorinfo(i, input_tensorinfo)
141 except Exception as e:
142 raise OnertError(f"Failed to set input tensor info #{i}: {e}") from e
143
144 try:
145 self.session.set_input(i, input_array)
146 except ValueError:
147 raise
148 except Exception as e:
149 raise OnertError(f"Failed to set input #{i}: {e}") from e
150
151 self.inputs.append(input_array)
152
uint64_t num_elems(const nnfw_tensorinfo *ti)
Definition minimal.cc:21

References Operation.inputs, Request.inputs, circlechef::CircleImport.inputs(), crew::Part.inputs, luci::CircleReader.inputs(), luci::PGroup.inputs, luci::pass::Expression.inputs, mio::circle::Reader.inputs(), moco::ModelSignature.inputs(), nnc::sir::CallFunction.inputs, validate_onnx2circle.OnnxRunner.inputs, tflinspect::Reader.inputs(), tflchef::TFliteImport.inputs(), tflread::Reader.inputs(), luci_interpreter::CircleReader.inputs(), onert_micro::core::reader::OMCircleReader.inputs(), moco::tf::test::TFNodeBuildTester.inputs(), moco::test::TFNodeBuildTester.inputs(), moco::test::TFNodeBuildTester.inputs(), luci::CircleFakeQuant.inputs(), moco::TFFakeQuantWithMinMaxVars.inputs(), luci::CircleAddN.inputs(), luci::CircleCustom.inputs(), luci::CircleAddN.inputs(), luci::CircleCustom.inputs(), loco::Graph.inputs(), nnkit::support::onnx::Runner.inputs(), ann::Operation.inputs(), loco::Graph.inputs(), luci::CircleFakeQuant.inputs(), moco::TFFakeQuantWithMinMaxVars.inputs(), onert_micro::execute::OMRuntimeKernel.inputs, nnfw_custom_kernel_params.inputs, package.common.basesession.BaseSession.inputs, onert::exec::IODescription.inputs, package.common.basesession.num_elems(), validate_onnx2circle.OnnxRunner.session, onert::api::python::NNFW_SESSION.session, and package.common.basesession.BaseSession.session.

Field Documentation

◆ inputs

package.common.basesession.BaseSession.inputs

◆ outputs

package.common.basesession.BaseSession.outputs

◆ session

package.common.basesession.BaseSession.session


The documentation for this class was generated from the following file: