ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnfw_session Struct Reference

Public Member Functions

 ~nnfw_session ()
 
NNFW_STATUS load_model_from_file (const char *package_file_path)
 
NNFW_STATUS train_set_traininfo (const nnfw_train_info *info)
 
NNFW_STATUS train_prepare ()
 
NNFW_STATUS train_input_tensorinfo (uint32_t index, nnfw_tensorinfo *ti)
 
NNFW_STATUS train_expected_tensorinfo (uint32_t index, nnfw_tensorinfo *ti)
 
NNFW_STATUS train_set_input (uint32_t index, void *input)
 
NNFW_STATUS train_set_expected (uint32_t index, void *expected)
 
NNFW_STATUS train_set_output (uint32_t index, NNFW_TYPE type, void *buffer, size_t length)
 
NNFW_STATUS train_run (bool update_weights)
 
NNFW_STATUS train_get_loss (uint32_t index, float *loss)
 
NNFW_STATUS train_export_circle (const char *path)
 
NNFW_STATUS train_export_checkpoint (const char *path)
 
NNFW_STATUS train_import_checkpoint (const char *path)
 

Static Public Member Functions

static NNFW_STATUS create (nnfw_session **session)
 Factory method. It creates and initializes an nnfw_session.
 

Detailed Description

Definition at line 68 of file onert-micro.cpp.

Constructor & Destructor Documentation

◆ ~nnfw_session()

nnfw_session::~nnfw_session ( )

Definition at line 165 of file onert-micro.cpp.

165{ delete _train_interpreter; }

Member Function Documentation

◆ create()

NNFW_STATUS nnfw_session::create ( nnfw_session **  session)
static

Factory method. It creates and initializes an nnfw_session.

Note
Use this factory method instead of the constructor so that the caller receives a status code.

Definition at line 149 of file onert-micro.cpp.

150{
151 if (session == nullptr)
153
154 auto new_session = std::unique_ptr<nnfw_session>(new nnfw_session());
155 *session = new_session.release();
156
157 if (*session == nullptr)
158 {
159 return NNFW_STATUS_ERROR;
160 }
161
163}
SessionID session(const coco::Module *m)
Definition Session.cpp:48
@ NNFW_STATUS_UNEXPECTED_NULL
Definition onert-micro.h:95
@ NNFW_STATUS_NO_ERROR
Definition onert-micro.h:88
@ NNFW_STATUS_ERROR
Definition onert-micro.h:93

References NNFW_STATUS_ERROR, NNFW_STATUS_NO_ERROR, and NNFW_STATUS_UNEXPECTED_NULL.

Referenced by nnfw_create_session().

◆ load_model_from_file()

NNFW_STATUS nnfw_session::load_model_from_file ( const char *  package_file_path)

Definition at line 272 of file onert-micro.cpp.

273{
274 _model_buf = readFile(file_path);
275 _config.model_ptr = _model_buf.data();
276 _config.model_size = _model_buf.size();
277 // load training info
278 loadTrainingInfo(_config.model_ptr);
279 // TODO: this import should start on nnfw_prepare if inference_interpreter is introduced
280 _train_interpreter->importTrainModel(_config.model_ptr, _config);
282}
OMStatus importTrainModel(char *model_ptr, const OMConfig &config)
DataBuffer readFile(const char *path)

References onert_micro::OMTrainingInterpreter::importTrainModel(), onert_micro::OMConfig::model_ptr, onert_micro::OMConfig::model_size, NNFW_STATUS_NO_ERROR, and readFile().

◆ train_expected_tensorinfo()

NNFW_STATUS nnfw_session::train_expected_tensorinfo ( uint32_t  index,
nnfw_tensorinfo *  ti 
)

◆ train_export_checkpoint()

NNFW_STATUS nnfw_session::train_export_checkpoint ( const char *  path)

Definition at line 324 of file onert-micro.cpp.

325{
326 _train_interpreter->saveCheckpoint(_config, path);
328}
OMStatus saveCheckpoint(const OMConfig &config, const char *save_path)

References NNFW_STATUS_NO_ERROR, and onert_micro::OMTrainingInterpreter::saveCheckpoint().

◆ train_export_circle()

NNFW_STATUS nnfw_session::train_export_circle ( const char *  path)

Definition at line 318 of file onert-micro.cpp.

319{
320 _train_interpreter->saveModel(_config, path);
322}
OMStatus saveModel(const OMConfig &config, const char *save_path)

References NNFW_STATUS_NO_ERROR, and onert_micro::OMTrainingInterpreter::saveModel().

◆ train_get_loss()

NNFW_STATUS nnfw_session::train_get_loss ( uint32_t  index,
float *  loss 
)

Definition at line 370 of file onert-micro.cpp.

371{
373 switch (_config.training_context.loss)
374 {
377 break;
378 default:
380 break;
381 }
382
383 _train_interpreter->evaluateMetric(_config, m, reinterpret_cast<void *>(loss),
386}
OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val, uint32_t test_size)
@ CROSS_ENTROPY
Definition OMConfig.h:54
@ CROSS_ENTROPY_METRICS
Definition OMConfig.h:42
OMTrainingContext training_context
Definition OMConfig.h:107

References onert_micro::OMTrainingContext::batch_size, onert_micro::CROSS_ENTROPY, onert_micro::CROSS_ENTROPY_METRICS, onert_micro::OMTrainingInterpreter::evaluateMetric(), onert_micro::OMTrainingContext::loss, m, NNFW_STATUS_NO_ERROR, and onert_micro::OMConfig::training_context.

◆ train_import_checkpoint()

NNFW_STATUS nnfw_session::train_import_checkpoint ( const char *  path)

Definition at line 330 of file onert-micro.cpp.

331{
332 _train_interpreter->loadCheckpoint(_config, path);
334}
OMStatus loadCheckpoint(OMConfig &config, const char *load_path)

References onert_micro::OMTrainingInterpreter::loadCheckpoint(), and NNFW_STATUS_NO_ERROR.

◆ train_input_tensorinfo()

NNFW_STATUS nnfw_session::train_input_tensorinfo ( uint32_t  index,
nnfw_tensorinfo *  ti 
)

◆ train_prepare()

NNFW_STATUS nnfw_session::train_prepare ( )

Definition at line 284 of file onert-micro.cpp.

285{
286 // TODO: Implement remaining jobs if inference_interpreter is introduced
287 // maybe interpreter initialization ?
289}

References NNFW_STATUS_NO_ERROR.

◆ train_run()

NNFW_STATUS nnfw_session::train_run ( bool  update_weights)

Definition at line 291 of file onert-micro.cpp.

292{
293 if (update_weights)
294 {
295 // TOOD: micro support update_weights ???
296 // Here we use this flag for distinguish inference and train in trainaing interpreter
297 _train_interpreter->trainSingleStep(_config);
300 }
301 else
302 {
303 // TODO: support multiple input/output
304 assert(outputbuf != nullptr);
305 _train_interpreter->allocateInputs();
306 float *allocated_input_data = (float *)_train_interpreter->getInputDataAt(0);
307 float *user_input_data = (float *)_train_interpreter->getInputData(0);
308 memcpy(allocated_input_data, user_input_data,
309 sizeof(float) * _train_interpreter->getInputSizeAt(0));
310 _train_interpreter->run(_config);
311 float *calculated_ptr = (float *)_train_interpreter->getOutputDataAt(0);
312 memcpy(outputbuf, calculated_ptr, sizeof(float) * _train_interpreter->getOutputSizeAt(0));
313 _train_interpreter->reset();
314 }
316}
uint32_t getOutputSizeAt(uint32_t position)
OMStatus run(const OMConfig &config)
uint32_t getInputSizeAt(uint32_t position)
OMStatus trainSingleStep(OMConfig &config)

References onert_micro::OMTrainingInterpreter::allocateInputs(), onert_micro::OMTrainingContext::batch_size, onert_micro::OMTrainingInterpreter::getInputData(), onert_micro::OMTrainingInterpreter::getInputDataAt(), onert_micro::OMTrainingInterpreter::getInputSizeAt(), onert_micro::OMTrainingInterpreter::getOutputDataAt(), onert_micro::OMTrainingInterpreter::getOutputSizeAt(), NNFW_STATUS_NO_ERROR, onert_micro::OMTrainingContext::num_epoch, onert_micro::OMTrainingContext::num_step, onert_micro::OMTrainingInterpreter::reset(), onert_micro::OMTrainingInterpreter::run(), onert_micro::OMConfig::training_context, and onert_micro::OMTrainingInterpreter::trainSingleStep().

◆ train_set_expected()

NNFW_STATUS nnfw_session::train_set_expected ( uint32_t  index,
void *  expected 
)

Definition at line 344 of file onert-micro.cpp.

345{
346 _train_interpreter->setTarget((uint8_t *)expected, index);
348}
void setTarget(uint8_t *data, uint32_t target_index)

References NNFW_STATUS_NO_ERROR, and onert_micro::OMTrainingInterpreter::setTarget().

◆ train_set_input()

NNFW_STATUS nnfw_session::train_set_input ( uint32_t  index,
void *  input 
)

Definition at line 337 of file onert-micro.cpp.

338{
339 _train_interpreter->setInput((uint8_t *)input, index);
341}
void setInput(uint8_t *data, uint32_t input_index)

References NNFW_STATUS_NO_ERROR, and onert_micro::OMTrainingInterpreter::setInput().

◆ train_set_output()

NNFW_STATUS nnfw_session::train_set_output ( uint32_t  index,
NNFW_TYPE  type,
void *  buffer,
size_t  length 
)

Definition at line 350 of file onert-micro.cpp.

352{
353 outputbuf = (uint8_t *)buffer;
355}

References NNFW_STATUS_NO_ERROR.

◆ train_set_traininfo()

NNFW_STATUS nnfw_session::train_set_traininfo ( const nnfw_train_info *  info)

Definition at line 357 of file onert-micro.cpp.

358{
359 _config.training_context.learning_rate = info->learning_rate;
360 _config.training_context.batch_size = info->batch_size;
363 _config.training_context.beta = info->adam_opt.beta;
364 _config.training_context.beta_squares = info->adam_opt.beta2;
365 _config.training_context.beta = info->adam_opt.epsilon;
366 _config.training_context.num_of_train_layers = info->num_trainble_ops;
368}
volatile const char info[]
@ NNFW_TRAIN_OPTIMIZER_ADAM
OMTrainOptimizer optimizer
Definition OMConfig.h:78

References onert_micro::ADAM, onert_micro::OMTrainingContext::batch_size, onert_micro::OMTrainingContext::beta, onert_micro::OMTrainingContext::beta_squares, info, onert_micro::OMTrainingContext::learning_rate, NNFW_STATUS_NO_ERROR, NNFW_TRAIN_OPTIMIZER_ADAM, onert_micro::OMTrainingContext::num_of_train_layers, onert_micro::OMTrainingContext::optimizer, onert_micro::SGD, and onert_micro::OMConfig::training_context.


The documentation for this struct was generated from the following file: