ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert_micro::OMTrainingInterpreter Class Reference

#include <OMTrainingInterpreter.h>

Public Member Functions

 OMTrainingInterpreter ()=default
 
 OMTrainingInterpreter (const OMTrainingInterpreter &)=delete
 
 OMTrainingInterpreter (OMTrainingInterpreter &&)=delete
 
OMTrainingInterpreter & operator= (const OMTrainingInterpreter &)=delete
 
OMTrainingInterpreter && operator= (const OMTrainingInterpreter &&)=delete
 
 ~OMTrainingInterpreter ()=default
 
OMStatus importTrainModel (char *model_ptr, const OMConfig &config)
 
void setInput (uint8_t *data, uint32_t input_index)
 
void setTarget (uint8_t *data, uint32_t target_index)
 
OMStatus trainSingleStep (OMConfig &config)
 
OMStatus reset ()
 
OMStatus evaluateMetric (const OMConfig &config, OMMetrics metric, void *metric_val, uint32_t test_size)
 
uint32_t getInputSizeAt (uint32_t position)
 
uint32_t getOutputSizeAt (uint32_t position)
 
OMStatus saveModel (const OMConfig &config, const char *save_path)
 
OMStatus saveCheckpoint (const OMConfig &config, const char *save_path)
 
OMStatus loadCheckpoint (OMConfig &config, const char *load_path)
 
OMStatus run (const OMConfig &config)
 
OMStatus allocateInputs ()
 
void * getInputData (uint32_t position)
 
void * getInputDataAt (uint32_t position)
 
void * getOutputDataAt (uint32_t position)
 

Detailed Description

Definition at line 28 of file OMTrainingInterpreter.h.

Constructor & Destructor Documentation

◆ OMTrainingInterpreter() [1/3]

onert_micro::OMTrainingInterpreter::OMTrainingInterpreter ( )
default

◆ OMTrainingInterpreter() [2/3]

onert_micro::OMTrainingInterpreter::OMTrainingInterpreter ( const OMTrainingInterpreter & )
delete

◆ OMTrainingInterpreter() [3/3]

onert_micro::OMTrainingInterpreter::OMTrainingInterpreter ( OMTrainingInterpreter &&  )
delete

◆ ~OMTrainingInterpreter()

onert_micro::OMTrainingInterpreter::~OMTrainingInterpreter ( )
default

Member Function Documentation

◆ allocateInputs()

OMStatus onert_micro::OMTrainingInterpreter::allocateInputs ( )
inline

Definition at line 91 of file OMTrainingInterpreter.h.

91{ return _training_runtime_module.allocateInputs(); }

References onert_micro::core::OMRuntimeModule::allocateInputs().

Referenced by nnfw_session::train_run().

◆ evaluateMetric()

OMStatus onert_micro::OMTrainingInterpreter::evaluateMetric ( const OMConfig & config,
OMMetrics  metric,
void *  metric_val,
uint32_t  test_size 
)
inline

Definition at line 71 of file OMTrainingInterpreter.h.

73 {
74 return _training_runtime_module.evaluateMetric(config, metric, metric_val, test_size);
75 }
OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val, uint32_t test_size)

References onert_micro::core::OMTrainingRuntimeModule::evaluateMetric().

Referenced by entry(), training_configure_tool::runTrainProcessWithCurConfig(), and nnfw_session::train_get_loss().

◆ getInputData()

void * OMTrainingInterpreter::getInputData ( uint32_t  position)

Definition at line 166 of file OMTrainingInterpreter.cpp.

167{
168 return _training_runtime_module.getInputData(position);
169}

References onert_micro::core::OMTrainingRuntimeModule::getInputData().

Referenced by nnfw_session::train_run().

◆ getInputDataAt()

void * OMTrainingInterpreter::getInputDataAt ( uint32_t  position)

Definition at line 156 of file OMTrainingInterpreter.cpp.

157{
158 return _training_runtime_module.getInputDataAt(position);
159}
void * getInputDataAt(uint32_t position)

References onert_micro::core::OMRuntimeModule::getInputDataAt().

Referenced by nnfw_session::train_run().

◆ getInputSizeAt()

uint32_t OMTrainingInterpreter::getInputSizeAt ( uint32_t  position)

Definition at line 41 of file OMTrainingInterpreter.cpp.

42{
43 return _training_runtime_module.getInputSizeAt(position);
44}
uint32_t getInputSizeAt(uint32_t position)

References onert_micro::core::OMRuntimeModule::getInputSizeAt().

Referenced by entry(), training_configure_tool::runTrainProcessWithCurConfig(), and nnfw_session::train_run().

◆ getOutputDataAt()

void * OMTrainingInterpreter::getOutputDataAt ( uint32_t  position)

Definition at line 161 of file OMTrainingInterpreter.cpp.

162{
163 return _training_runtime_module.getOutputDataAt(position);
164}
void * getOutputDataAt(uint32_t position)

References onert_micro::core::OMRuntimeModule::getOutputDataAt().

Referenced by nnfw_session::train_run().

◆ getOutputSizeAt()

uint32_t OMTrainingInterpreter::getOutputSizeAt ( uint32_t  position)

Definition at line 46 of file OMTrainingInterpreter.cpp.

47{
48 return _training_runtime_module.getOutputSizeAt(position);
49}
uint32_t getOutputSizeAt(uint32_t position)

References onert_micro::core::OMRuntimeModule::getOutputSizeAt().

Referenced by entry(), training_configure_tool::runTrainProcessWithCurConfig(), and nnfw_session::train_run().

◆ importTrainModel()

OMStatus OMTrainingInterpreter::importTrainModel ( char *  model_ptr,
const OMConfig & config 
)

Definition at line 25 of file OMTrainingInterpreter.cpp.

26{
27 assert(model_ptr != nullptr && "Model ptr shouldn't be nullptr");
28 if (model_ptr == nullptr)
29 return UnknownError;
30
31 return _training_runtime_module.importTrainModel(model_ptr, config);
32}
OMStatus importTrainModel(char *model_ptr, const OMConfig &config)

References onert_micro::core::OMTrainingRuntimeModule::importTrainModel(), and onert_micro::UnknownError.

Referenced by entry(), nnfw_session::load_model_from_file(), and training_configure_tool::runTrainProcessWithCurConfig().

◆ loadCheckpoint()

OMStatus OMTrainingInterpreter::loadCheckpoint ( OMConfig & config,
const char *  load_path 
)

Definition at line 77 of file OMTrainingInterpreter.cpp.

78{
79 // Not imported or path is empty
80 if (load_path == nullptr or config.model_ptr == nullptr or config.model_size == 0)
81 return UnknownError;
82
83 // Get DataBuffer (vector of chars) of checkpoints
84 std::vector<char> checkpoint_data;
85
86 // Read data
87#ifndef DIS_STREAM
88 std::ifstream file(load_path, std::ios::binary | std::ios::in);
89 if (!file.good())
90 {
91 assert(false && "Fail to open");
92 return UnknownError;
93 }
94
95 file.seekg(0, std::ios::end);
96 auto fileSize = file.tellg();
97 file.seekg(0, std::ios::beg);
98
99 // reserve capacity
100 checkpoint_data.resize(fileSize);
101
102 // read the data
103 file.read(checkpoint_data.data(), fileSize);
104 if (file.fail())
105 {
106 assert(false && "Fail to read");
107 return UnknownError;
108 }
109#else
110 assert(false && "Not supported");
111 return UnknownError;
112#endif // DIS_STREAM
113
114 // Load data
115 OMStatus status = _training_runtime_module.loadCheckpointData(config, checkpoint_data.data());
116
117 return status;
118}
OMStatus loadCheckpointData(OMConfig &config, const char *data)

References onert_micro::core::OMTrainingRuntimeModule::loadCheckpointData(), and onert_micro::UnknownError.

Referenced by entry(), and nnfw_session::train_import_checkpoint().

◆ operator=() [1/2]

OMTrainingInterpreter && onert_micro::OMTrainingInterpreter::operator= ( const OMTrainingInterpreter &&  )
delete

◆ operator=() [2/2]

OMTrainingInterpreter & onert_micro::OMTrainingInterpreter::operator= ( const OMTrainingInterpreter & )
delete

◆ reset()

OMStatus OMTrainingInterpreter::reset ( )

◆ run()

OMStatus onert_micro::OMTrainingInterpreter::run ( const OMConfig & config)
inline

Definition at line 90 of file OMTrainingInterpreter.h.

90{ return _training_runtime_module.run(config); }
OMStatus run(const OMConfig &config)

References onert_micro::core::OMRuntimeModule::run().

Referenced by package.infer.session::inference(), and nnfw_session::train_run().

◆ saveCheckpoint()

OMStatus OMTrainingInterpreter::saveCheckpoint ( const OMConfig & config,
const char *  save_path 
)

Definition at line 120 of file OMTrainingInterpreter.cpp.

121{
122 // Not imported or path is empty
123 if (save_path == nullptr or config.model_ptr == nullptr or config.model_size == 0)
124 return UnknownError;
125
126 // Get DataBuffer (vector of chars) of checkpoints
127 std::vector<char> checkpoint_data;
128
129 OMStatus status = _training_runtime_module.createCheckpointFile(config, checkpoint_data);
130
131 assert(status == Ok);
132 if (status != Ok)
133 return status;
134
135 // Save it into save_path
136#ifndef DIS_STREAM
137 // Open or create file
138 // Note: if the file existed, it will be overwritten
139 std::ofstream out_file(save_path, std::ios::binary | std::ios::trunc);
140 if (not out_file.is_open())
141 return UnknownError;
142
143 // Write data
144 out_file.write(checkpoint_data.data(), checkpoint_data.size());
145
146 // Close file
147 out_file.close();
148#else
149 assert(false && "Not supported");
150 return UnknownError;
151#endif // DIS_STREAM
152
153 return Ok;
154}
OMStatus createCheckpointFile(const OMConfig &config, std::vector< char > &data_buffer)

References onert_micro::core::OMTrainingRuntimeModule::createCheckpointFile(), onert_micro::Ok, and onert_micro::UnknownError.

Referenced by entry(), and nnfw_session::train_export_checkpoint().

◆ saveModel()

OMStatus OMTrainingInterpreter::saveModel ( const OMConfig & config,
const char *  save_path 
)

Definition at line 51 of file OMTrainingInterpreter.cpp.

52{
53 if (save_path == nullptr or config.model_size == 0 or config.model_ptr == nullptr)
54 return UnknownError;
55
56#ifndef DIS_STREAM
57 // Open or create file
58 // Note: if the file existed, it will be overwritten
59 std::ofstream out_file(save_path, std::ios::binary | std::ios::trunc);
60 if (not out_file.is_open())
61 return UnknownError;
62
63 // Write data
64 out_file.write(config.model_ptr, config.model_size);
65
66 // Close file
67 out_file.close();
68#else
69 assert(false && "Not supported");
70 return UnknownError;
71#endif // DIS_STREAM
72
73 // Saving done
74 return Ok;
75}

References onert_micro::Ok, and onert_micro::UnknownError.

Referenced by entry(), and nnfw_session::train_export_circle().

◆ setInput()

void onert_micro::OMTrainingInterpreter::setInput ( uint8_t *  data,
uint32_t  input_index 
)
inline

Definition at line 47 of file OMTrainingInterpreter.h.

48 {
49 _training_runtime_module.setInputData(data, input_index);
50 }
void setInputData(uint8_t *data, uint32_t input_index)

References onert_micro::core::OMTrainingRuntimeModule::setInputData().

Referenced by entry(), training_configure_tool::runTrainProcessWithCurConfig(), and nnfw_session::train_set_input().

◆ setTarget()

void onert_micro::OMTrainingInterpreter::setTarget ( uint8_t *  data,
uint32_t  target_index 
)
inline

Definition at line 53 of file OMTrainingInterpreter.h.

54 {
55 _training_runtime_module.setTargetData(data, target_index);
56 }
void setTargetData(uint8_t *data, uint32_t target_index)

References onert_micro::core::OMTrainingRuntimeModule::setTargetData().

Referenced by entry(), training_configure_tool::runTrainProcessWithCurConfig(), and nnfw_session::train_set_expected().

◆ trainSingleStep()

OMStatus OMTrainingInterpreter::trainSingleStep ( OMConfig & config)

The documentation for this class was generated from the following files: