26#include <circle-generated/circle/schema_generated.h>
27#include <circle-generated/circle/traininfo_generated.h>
29#define NNFW_RETURN_ERROR_IF_NULL(p) \
33 return NNFW_STATUS_UNEXPECTED_NULL; \
41 std::ifstream file(path, std::ios::binary | std::ios::in);
44 std::string errmsg =
"Failed to open file";
45 std::cerr << errmsg << std::endl;
49 file.seekg(0, std::ios::end);
50 auto fileSize = file.tellg();
51 file.seekg(0, std::ios::beg);
57 file.read(model_data.data(), fileSize);
60 std::string errmsg =
"Failed to read file";
61 std::cerr << errmsg << std::endl;
102 uint32_t getInputSize();
103 uint32_t getOutputSize();
105 NNFW_STATUS loadOptimizerInfo(
const circle::ModelTraining *circle_model);
106 NNFW_STATUS loadLossInfo(
const circle::ModelTraining *circle_model);
107 NNFW_STATUS loadTrainableOps(
const circle::ModelTraining *circle_model,
int num_ops);
113 std::string _model_path;
117nnfw_session::nnfw_session() : _train_interpreter{new
onert_micro::OMTrainingInterpreter()}
121 const uint32_t training_epochs = 10;
122 const float learning_rate = 0.001f;
123 const uint32_t num_train_layers = 10;
126 const float beta = 0.9;
127 const float beta_squares = 0.999;
128 const float epsilon = 1e-07;
136 train_context.
loss = loss;
138 train_context.
beta = beta;
140 train_context.
epsilon = epsilon;
151 if (session ==
nullptr)
154 auto new_session = std::unique_ptr<nnfw_session>(
new nnfw_session());
155 *session = new_session.release();
157 if (*session ==
nullptr)
167NNFW_STATUS nnfw_session::loadOptimizerInfo(
const circle::ModelTraining *circle_model)
169 assert(circle_model !=
nullptr);
171 const circle::Optimizer circle_opt = circle_model->optimizer();
175 case circle::Optimizer_SGD:
178 circle_model->optimizer_opt_as_SGDOptions()->learning_rate();
180 case circle::Optimizer_ADAM:
183 circle_model->optimizer_opt_as_AdamOptions()->learning_rate();
186 circle_model->optimizer_opt_as_AdamOptions()->beta_2();
190 std::cerr <<
"unknown optimzer" << std::endl;
196NNFW_STATUS nnfw_session::loadLossInfo(
const circle::ModelTraining *circle_model)
198 assert(circle_model !=
nullptr);
201 const circle::LossFn circle_loss =
circle_model->lossfn();
205 case circle::LossFn::LossFn_CATEGORICAL_CROSSENTROPY:
208 case circle::LossFn::LossFn_MEAN_SQUARED_ERROR:
211 case circle::LossFn::LossFn_SPARSE_CATEGORICAL_CROSSENTROPY:
213 std::cerr <<
"'sparse_categorical_crossentropy' is not supported yet" << std::endl;
216 std::cerr <<
"unknown loss function" << std::endl;
222NNFW_STATUS nnfw_session::loadTrainableOps(
const circle::ModelTraining *circle_model,
int num_ops)
224 assert(circle_model !=
nullptr);
228 if (ops_list !=
nullptr)
230 num_ops - ops_list->data()[0];
236NNFW_STATUS nnfw_session::loadTrainingInfo(
char *buf)
238 auto model = circle::GetModel(buf);
239 auto num_ops =
model->subgraphs()->Get(0)->operators()->size();
241 auto const metadata_list =
model->metadata();
242 const uint8_t *
data =
nullptr;
243 if (metadata_list !=
nullptr)
245 for (uint32_t i = 0; i < metadata_list->size(); ++i)
247 const auto metadata = metadata_list->Get(i);
248 if (strcmp(metadata->name()->c_str(),
"CIRCLE_TRAINING") != 0)
250 data = (
model->buffers()->Get(metadata->buffer()))->
data()->data();
254 const circle::ModelTraining *traininfo_model =
255 circle::GetModelTraining(
static_cast<const void *
>(data));
257 NNFW_STATUS status = loadOptimizerInfo(traininfo_model);
261 status = loadLossInfo(traininfo_model);
265 status = loadTrainableOps(traininfo_model, num_ops);
304 assert(outputbuf !=
nullptr);
306 float *allocated_input_data = (
float *)_train_interpreter->
getInputDataAt(0);
307 float *user_input_data = (
float *)_train_interpreter->
getInputData(0);
308 memcpy(allocated_input_data, user_input_data,
310 _train_interpreter->
run(_config);
311 float *calculated_ptr = (
float *)_train_interpreter->
getOutputDataAt(0);
312 memcpy(outputbuf, calculated_ptr,
sizeof(
float) * _train_interpreter->
getOutputSizeAt(0));
313 _train_interpreter->
reset();
320 _train_interpreter->
saveModel(_config, path);
339 _train_interpreter->
setInput((uint8_t *)input, index);
346 _train_interpreter->
setTarget((uint8_t *)expected, index);
353 outputbuf = (uint8_t *)buffer;
383 _train_interpreter->
evaluateMetric(_config,
m,
reinterpret_cast<void *
>(loss),
394 return session->load_model_from_file(package_file_path);
401 return session->train_run(update_weights);
406 return session->train_export_circle(path);
411 return session->train_export_checkpoint(path);
416 return session->train_import_checkpoint(path);
423 return session->train_set_input(index, input);
430 return session->train_set_expected(index, expected);
436 return session->train_get_loss(index, loss);
442 return session->train_set_traininfo(
info);
446 void *buffer,
size_t length)
449 return session->train_set_output(index, type, buffer, length);
uint32_t getOutputSizeAt(uint32_t position)
OMStatus run(const OMConfig &config)
OMStatus saveModel(const OMConfig &config, const char *save_path)
void * getInputData(uint32_t position)
void setTarget(uint8_t *data, uint32_t target_index)
OMStatus importTrainModel(char *model_ptr, const OMConfig &config)
void setInput(uint8_t *data, uint32_t input_index)
void * getInputDataAt(uint32_t position)
void * getOutputDataAt(uint32_t position)
uint32_t getInputSizeAt(uint32_t position)
OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val, uint32_t test_size)
OMStatus allocateInputs()
OMStatus trainSingleStep(OMConfig &config)
OMStatus loadCheckpoint(OMConfig &config, const char *load_path)
OMStatus saveCheckpoint(const OMConfig &config, const char *save_path)
volatile const char info[]
NNFW_STATUS nnfw_create_session(nnfw_session **session)
Create a new session instance.
NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t index, float *loss)
Get loss value for expected output.
#define NNFW_RETURN_ERROR_IF_NULL(p)
NNFW_STATUS nnfw_train_export_checkpoint(nnfw_session *session, const char *path)
Export circle checkpoint.
NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *package_file_path)
Load model from nnpackage file or directory.
NNFW_STATUS nnfw_train_prepare(nnfw_session *session)
Prepare session to be ready for training.
NNFW_STATUS nnfw_train(nnfw_session *session, bool update_weights)
Train the model.
NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_info *info)
Set training information.
DataBuffer readFile(const char *path)
NNFW_STATUS nnfw_train_set_input(nnfw_session *session, uint32_t index, void *input, const nnfw_tensorinfo *input_info)
Set training input.
NNFW_STATUS nnfw_train_set_output(nnfw_session *session, uint32_t index, NNFW_TYPE type, void *buffer, size_t length)
Set training output buffer.
NNFW_STATUS nnfw_train_import_checkpoint(nnfw_session *session, const char *path)
Import circle checkpoint.
NNFW_STATUS nnfw_train_set_expected(nnfw_session *session, uint32_t index, void *expected, const nnfw_tensorinfo *expected_info)
Set training expected output.
NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *path)
Export current training model into circle model.
std::vector< char > DataBuffer
NNFW_STATUS
Result values returned from a call to an API function.
@ NNFW_STATUS_UNEXPECTED_NULL
@ NNFW_TRAIN_OPTIMIZER_ADAM
NNFW_STATUS train_prepare()
NNFW_STATUS train_set_output(uint32_t index, NNFW_TYPE type, void *buffer, size_t length)
NNFW_STATUS train_set_expected(uint32_t index, void *expected)
NNFW_STATUS train_run(bool update_weights)
NNFW_STATUS train_export_checkpoint(const char *path)
NNFW_STATUS train_import_checkpoint(const char *path)
NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti)
NNFW_STATUS train_set_input(uint32_t index, void *input)
NNFW_STATUS train_get_loss(uint32_t index, float *loss)
NNFW_STATUS load_model_from_file(const char *package_file_path)
NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti)
NNFW_STATUS train_export_circle(const char *path)
static NNFW_STATUS create(nnfw_session **session)
Factory method. It creates and initializes an nnfw_session.
NNFW_STATUS train_set_traininfo(const nnfw_train_info *info)
Tensor info describes the type and shape of tensors.
Training information to prepare training.
OMTrainingContext training_context
OMTrainOptimizer optimizer
uint32_t num_of_train_layers