ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnfw_traininfo_bindings.cc File Reference

Go to the source code of this file.

Functions

void bind_nnfw_train_enums (py::module_ &m)
 
void bind_nnfw_loss_info (py::module_ &m)
 
void bind_nnfw_train_info (py::module_ &m)
 

Function Documentation

◆ bind_nnfw_loss_info()

void bind_nnfw_loss_info (py::module_ &m)

Definition at line 54 of file nnfw_traininfo_bindings.cc.

55{
56 py::class_<nnfw_loss_info>(m, "lossinfo", py::module_local())
57 .def(py::init<>()) // Default constructor
58 .def_readwrite("loss", &nnfw_loss_info::loss, "Loss type")
59 .def_readwrite("reduction_type", &nnfw_loss_info::reduction_type, "Reduction type");
60}
NNFW_TRAIN_LOSS_REDUCTION reduction_type
NNFW_TRAIN_LOSS loss

References nnfw_loss_info::loss, m, and nnfw_loss_info::reduction_type.

Referenced by PYBIND11_MODULE().

◆ bind_nnfw_train_enums()

void bind_nnfw_train_enums (py::module_ &m)

Definition at line 26 of file nnfw_traininfo_bindings.cc.

27{
28 // Bind NNFW_TRAIN_LOSS
29 py::enum_<NNFW_TRAIN_LOSS>(m, "loss", py::module_local())
30 .value("UNDEFINED", NNFW_TRAIN_LOSS_UNDEFINED)
31 .value("MEAN_SQUARED_ERROR", NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR)
32 .value("CATEGORICAL_CROSSENTROPY", NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY);
33
34 // Bind NNFW_TRAIN_LOSS_REDUCTION
35 py::enum_<NNFW_TRAIN_LOSS_REDUCTION>(m, "loss_reduction", py::module_local())
36 .value("UNDEFINED", NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED)
37 .value("SUM_OVER_BATCH_SIZE", NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE)
38 .value("SUM", NNFW_TRAIN_LOSS_REDUCTION_SUM);
39
40 // Bind NNFW_TRAIN_OPTIMIZER
41 py::enum_<NNFW_TRAIN_OPTIMIZER>(m, "optimizer", py::module_local())
42 .value("UNDEFINED", NNFW_TRAIN_OPTIMIZER_UNDEFINED)
43 .value("SGD", NNFW_TRAIN_OPTIMIZER_SGD)
44 .value("ADAM", NNFW_TRAIN_OPTIMIZER_ADAM);
45
46 // Bind NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES
47 py::enum_<NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES>(m, "trainable_ops", py::module_local())
48 .value("INCORRECT_STATE", NNFW_TRAIN_TRAINABLE_INCORRECT_STATE)
49 .value("ALL", NNFW_TRAIN_TRAINABLE_ALL)
50 .value("NONE", NNFW_TRAIN_TRAINABLE_NONE);
51}
@ NNFW_TRAIN_TRAINABLE_NONE
@ NNFW_TRAIN_TRAINABLE_ALL
@ NNFW_TRAIN_TRAINABLE_INCORRECT_STATE
@ NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED
@ NNFW_TRAIN_LOSS_REDUCTION_SUM
@ NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE
@ NNFW_TRAIN_OPTIMIZER_ADAM
@ NNFW_TRAIN_OPTIMIZER_SGD
@ NNFW_TRAIN_OPTIMIZER_UNDEFINED
@ NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR
@ NNFW_TRAIN_LOSS_UNDEFINED
@ NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY

References m, NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY, NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR, NNFW_TRAIN_LOSS_REDUCTION_SUM, NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE, NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED, NNFW_TRAIN_LOSS_UNDEFINED, NNFW_TRAIN_OPTIMIZER_ADAM, NNFW_TRAIN_OPTIMIZER_SGD, NNFW_TRAIN_OPTIMIZER_UNDEFINED, NNFW_TRAIN_TRAINABLE_ALL, NNFW_TRAIN_TRAINABLE_INCORRECT_STATE, and NNFW_TRAIN_TRAINABLE_NONE.

Referenced by PYBIND11_MODULE().

◆ bind_nnfw_train_info()

void bind_nnfw_train_info (py::module_ &m)

Definition at line 63 of file nnfw_traininfo_bindings.cc.

64{
65 py::class_<nnfw_train_info>(m, "traininfo", py::module_local())
66 .def(py::init<>()) // Default constructor
67 .def_readwrite("learning_rate", &nnfw_train_info::learning_rate, "Learning rate")
68 .def_readwrite("batch_size", &nnfw_train_info::batch_size, "Batch size")
69 .def_readwrite("loss_info", &nnfw_train_info::loss_info, "Loss information")
70 .def_readwrite("opt", &nnfw_train_info::opt, "Optimizer type")
71 .def_readwrite("num_of_trainable_ops", &nnfw_train_info::num_of_trainable_ops,
72 "Number of trainable operations");
73}
uint32_t batch_size
nnfw_loss_info loss_info
NNFW_TRAIN_OPTIMIZER opt

References nnfw_train_info::batch_size, nnfw_train_info::learning_rate, nnfw_train_info::loss_info, m, nnfw_train_info::num_of_trainable_ops, and nnfw_train_info::opt.

Referenced by PYBIND11_MODULE().