// Python bindings for NNFW_SESSION, registered on module `m` under the
// Python name "nnfw_session". py::module_local() (pybind11) keeps this type
// registration local to this extension module.
// Bound surface visible below: construction from (package_file_path,
// backends), typed input/output buffer setters, input/output count getters,
// input/output layout setters (layout defaults to "NONE"), and per-index
// input/output tensor-info getters.
// NOTE(review): this fragment carries stray leading line numbers (e.g. "29",
// "31") and is missing lines such as the `.def(` openers and lambda headers;
// it will not compile as shown -- restore from the upstream file.
29 py::class_<NNFW_SESSION>(
m,
"nnfw_session", py::module_local())
// Constructor: NNFW_SESSION(const char *package_file_path, const char *backends).
31 py::init<const char *, const char *>(), py::arg(
"package_file_path"), py::arg(
"backends"),
32 "Create a new session instance, load model from nnpackage file or directory, "
33 "set available backends and prepare session to be ready for inference\n"
35 "\tpackage_file_path (str): Path to the nnpackage file or unzipped directory to be loaded\n"
36 "\tbackends (str): Available backends on which nnfw uses\n"
37 "\t\tMultiple backends can be set and they must be separated by a semicolon "
38 "(ex: \"acl_cl;cpu\")\n"
39 "\t\tAmong the multiple backends, the 1st element is used as the default backend.")
// Input tensor-info setter (takes an index and a tensor_info).
41 py::arg(
"tensor_info"),
42 "Set input model's tensor info for resizing.\n"
44 "\tindex (int): Index of input to be set (0-indexed)\n"
45 "\ttensor_info (tensorinfo): Tensor info to be set")
// Typed input-buffer setters: each overload forwards (index, buffer) to
// NNFW_SESSION::set_input<T>, for T = float, int, uint8_t, bool, int64_t,
// int8_t, int16_t (in that registration order).
53 session.set_input<
float>(index, buffer);
55 py::arg(
"index"), py::arg(
"buffer"),
58 "\tindex (int): Index of input to be set (0-indexed)\n"
59 "\tbuffer (numpy): Raw buffer for input")
63 session.set_input<
int>(index, buffer);
65 py::arg(
"index"), py::arg(
"buffer"),
68 "\tindex (int): Index of input to be set (0-indexed)\n"
69 "\tbuffer (numpy): Raw buffer for input")
73 session.set_input<uint8_t>(index, buffer);
75 py::arg(
"index"), py::arg(
"buffer"),
78 "\tindex (int): Index of input to be set (0-indexed)\n"
79 "\tbuffer (numpy): Raw buffer for input")
83 session.set_input<
bool>(index, buffer);
85 py::arg(
"index"), py::arg(
"buffer"),
88 "\tindex (int): Index of input to be set (0-indexed)\n"
89 "\tbuffer (numpy): Raw buffer for input")
93 session.set_input<int64_t>(index, buffer);
95 py::arg(
"index"), py::arg(
"buffer"),
98 "\tindex (int): Index of input to be set (0-indexed)\n"
99 "\tbuffer (numpy): Raw buffer for input")
103 session.set_input<int8_t>(index, buffer);
105 py::arg(
"index"), py::arg(
"buffer"),
108 "\tindex (int): Index of input to be set (0-indexed)\n"
109 "\tbuffer (numpy): Raw buffer for input")
113 session.set_input<int16_t>(index, buffer);
115 py::arg(
"index"), py::arg(
"buffer"),
118 "\tindex (int): Index of input to be set (0-indexed)\n"
119 "\tbuffer (numpy): Raw buffer for input")
// Typed output-buffer setters: each overload forwards (index, buffer) to
// NNFW_SESSION::set_output<T>, for the same element types as set_input above.
123 session.set_output<
float>(index, buffer);
125 py::arg(
"index"), py::arg(
"buffer"),
126 "Set output buffer\n"
128 "\tindex (int): Index of output to be set (0-indexed)\n"
129 "\tbuffer (numpy): Raw buffer for output")
133 session.set_output<
int>(index, buffer);
135 py::arg(
"index"), py::arg(
"buffer"),
136 "Set output buffer\n"
138 "\tindex (int): Index of output to be set (0-indexed)\n"
139 "\tbuffer (numpy): Raw buffer for output")
143 session.set_output<uint8_t>(index, buffer);
145 py::arg(
"index"), py::arg(
"buffer"),
146 "Set output buffer\n"
148 "\tindex (int): Index of output to be set (0-indexed)\n"
149 "\tbuffer (numpy): Raw buffer for output")
153 session.set_output<
bool>(index, buffer);
155 py::arg(
"index"), py::arg(
"buffer"),
156 "Set output buffer\n"
158 "\tindex (int): Index of output to be set (0-indexed)\n"
159 "\tbuffer (numpy): Raw buffer for output")
163 session.set_output<int64_t>(index, buffer);
165 py::arg(
"index"), py::arg(
"buffer"),
166 "Set output buffer\n"
168 "\tindex (int): Index of output to be set (0-indexed)\n"
169 "\tbuffer (numpy): Raw buffer for output")
173 session.set_output<int8_t>(index, buffer);
175 py::arg(
"index"), py::arg(
"buffer"),
176 "Set output buffer\n"
178 "\tindex (int): Index of output to be set (0-indexed)\n"
179 "\tbuffer (numpy): Raw buffer for output")
183 session.set_output<int16_t>(index, buffer);
185 py::arg(
"index"), py::arg(
"buffer"),
186 "Set output buffer\n"
188 "\tindex (int): Index of output to be set (0-indexed)\n"
189 "\tbuffer (numpy): Raw buffer for output")
// Input/output count getters (only their docstrings survive in this fragment).
191 "Get the number of inputs defined in loaded model\n"
193 "\tint: The number of inputs")
195 "Get the number of outputs defined in loaded model\n"
197 "\tint: The number of outputs")
// Input/output layout setters; `layout` defaults to "NONE" when omitted.
199 py::arg(
"layout") =
"NONE",
200 "Set the layout of an input\n"
202 "\tindex (int): Index of input to be set (0-indexed)\n"
203 "\tlayout (str): Layout to set to target input")
205 py::arg(
"layout") =
"NONE",
206 "Set the layout of an output\n"
208 "\tindex (int): Index of output to be set (0-indexed)\n"
209 "\tlayout (str): Layout to set to target output")
// Per-index input/output tensor-info getters.
211 "Get i-th input tensor info\n"
213 "\tindex (int): Index of input\n"
215 "\ttensorinfo: Tensor info (shape, type, etc)")
217 "Get i-th output tensor info\n"
219 "\tindex (int): Index of output\n"
221 "\ttensorinfo: Tensor info (shape, type, etc)");
// Extend the already-registered "nnfw_session" Python type with training
// APIs: fetch it back from module `m` and cast it to
// py::class_<NNFW_SESSION> so the following .def() calls attach to the same
// Python type.
// NOTE(review): this fragment carries stray leading line numbers and is
// missing lines (e.g. several `.def(` openers); it will not compile as
// shown -- restore from the upstream file.
228 m.attr(
"nnfw_session")
229 .cast<py::class_<NNFW_SESSION>>()
// Training info getter/setter, training step, and loss getter (method names
// are not visible in this fragment; only their docstrings survive).
231 "Retrieve training information for the model.")
233 "Set training information for the model.")
236 "Run a training step, optionally updating weights.")
238 "Retrieve the training loss for a specific index.")
// Typed training tensor setters: train_set_input, train_set_expected and
// train_set_output are each overloaded for float, int and uint8_t, all taking
// (index, buffer).
// NOTE(review): the inference set_input/set_output bindings also accept
// bool, int64_t, int8_t and int16_t; confirm whether the narrower type set
// here is intentional.
239 .def(
"train_set_input", &NNFW_SESSION::train_set_input<float>, py::arg(
"index"),
240 py::arg(
"buffer"),
"Set training input tensor for the given index (float).")
241 .def(
"train_set_input", &NNFW_SESSION::train_set_input<int>, py::arg(
"index"),
242 py::arg(
"buffer"),
"Set training input tensor for the given index (int).")
243 .def(
"train_set_input", &NNFW_SESSION::train_set_input<uint8_t>, py::arg(
"index"),
244 py::arg(
"buffer"),
"Set training input tensor for the given index (uint8).")
245 .def(
"train_set_expected", &NNFW_SESSION::train_set_expected<float>, py::arg(
"index"),
246 py::arg(
"buffer"),
"Set expected output tensor for the given index (float).")
247 .def(
"train_set_expected", &NNFW_SESSION::train_set_expected<int>, py::arg(
"index"),
248 py::arg(
"buffer"),
"Set expected output tensor for the given index (int).")
249 .def(
"train_set_expected", &NNFW_SESSION::train_set_expected<uint8_t>, py::arg(
"index"),
250 py::arg(
"buffer"),
"Set expected output tensor for the given index (uint8).")
251 .def(
"train_set_output", &NNFW_SESSION::train_set_output<float>, py::arg(
"index"),
252 py::arg(
"buffer"),
"Set output tensor for the given index (float).")
253 .def(
"train_set_output", &NNFW_SESSION::train_set_output<int>, py::arg(
"index"),
254 py::arg(
"buffer"),
"Set output tensor for the given index (int).")
255 .def(
"train_set_output", &NNFW_SESSION::train_set_output<uint8_t>, py::arg(
"index"),
256 py::arg(
"buffer"),
"Set output tensor for the given index (uint8).")
// Model export and checkpoint import/export (method names not visible here).
258 "Export the trained model to a circle file.")
260 "Import a training checkpoint from a file.")
262 "Export the training checkpoint to a file.");