ONE - On-device Neural Engine
nnfw_traininfo_bindings.cc
/*
 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnfw_traininfo_bindings.h"

#include "nnfw_api_wrapper.h"

namespace py = pybind11;

using namespace onert::api::python;

// Bind the training-related enums to the Python module
void bind_nnfw_train_enums(py::module_ &m)
{
  // Bind NNFW_TRAIN_LOSS
  py::enum_<NNFW_TRAIN_LOSS>(m, "loss", py::module_local())
    .value("UNDEFINED", NNFW_TRAIN_LOSS_UNDEFINED)
    .value("MEAN_SQUARED_ERROR", NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR)
    .value("CATEGORICAL_CROSSENTROPY", NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY);

  // Bind NNFW_TRAIN_LOSS_REDUCTION
  py::enum_<NNFW_TRAIN_LOSS_REDUCTION>(m, "loss_reduction", py::module_local())
    .value("UNDEFINED", NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED)
    .value("SUM_OVER_BATCH_SIZE", NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE)
    .value("SUM", NNFW_TRAIN_LOSS_REDUCTION_SUM);

  // Bind NNFW_TRAIN_OPTIMIZER
  py::enum_<NNFW_TRAIN_OPTIMIZER>(m, "optimizer", py::module_local())
    .value("UNDEFINED", NNFW_TRAIN_OPTIMIZER_UNDEFINED)
    .value("SGD", NNFW_TRAIN_OPTIMIZER_SGD)
    .value("ADAM", NNFW_TRAIN_OPTIMIZER_ADAM);

  // Bind NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES
  py::enum_<NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES>(m, "trainable_ops", py::module_local())
    .value("INCORRECT_STATE", NNFW_TRAIN_TRAINABLE_INCORRECT_STATE)
    .value("ALL", NNFW_TRAIN_TRAINABLE_ALL)
    .value("NONE", NNFW_TRAIN_TRAINABLE_NONE);
}

// Bind the nnfw_loss_info struct
void bind_nnfw_loss_info(py::module_ &m)
{
  py::class_<nnfw_loss_info>(m, "lossinfo", py::module_local())
    .def(py::init<>()) // Default constructor
    .def_readwrite("loss", &nnfw_loss_info::loss, "Loss type")
    .def_readwrite("reduction_type", &nnfw_loss_info::reduction_type, "Reduction type");
}

// Bind the nnfw_train_info struct
void bind_nnfw_train_info(py::module_ &m)
{
  py::class_<nnfw_train_info>(m, "traininfo", py::module_local())
    .def(py::init<>()) // Default constructor
    .def_readwrite("learning_rate", &nnfw_train_info::learning_rate, "Learning rate")
    .def_readwrite("batch_size", &nnfw_train_info::batch_size, "Batch size")
    .def_readwrite("loss_info", &nnfw_train_info::loss_info, "Loss information")
    .def_readwrite("opt", &nnfw_train_info::opt, "Optimizer type")
    .def_readwrite("num_of_trainable_ops", &nnfw_train_info::num_of_trainable_ops,
                   "Number of trainable operations");
}
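
The three bind_* functions above only define the bindings; they still need to be registered on a pybind11 module object elsewhere. The following is a minimal sketch of that registration step, assuming a standard PYBIND11_MODULE entry point; the module name nnfw_traininfo_example and the doc string are illustrative placeholders, not the module definition actually used by onert.

// Minimal sketch: wiring the binding functions into a pybind11 module.
// The module name below is illustrative only.
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Declarations of the binding functions defined in nnfw_traininfo_bindings.cc
void bind_nnfw_train_enums(py::module_ &m);
void bind_nnfw_loss_info(py::module_ &m);
void bind_nnfw_train_info(py::module_ &m);

PYBIND11_MODULE(nnfw_traininfo_example, m)
{
  m.doc() = "Example module exposing the training-info bindings";

  bind_nnfw_train_enums(m); // enums: loss, loss_reduction, optimizer, trainable_ops
  bind_nnfw_loss_info(m);   // class: lossinfo
  bind_nnfw_train_info(m);  // class: traininfo
}

Because every enum and class here is registered with py::module_local(), the Python-side names (loss, optimizer, lossinfo, traininfo, and so on) stay scoped to whichever module performs this registration instead of being shared globally across extension modules.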