ONE - On-device Neural Engine
All | Data Structures | Namespaces | Files | Functions | Variables | Typedefs | Enumerations | Enumerator | Friends | Macros | Modules | Pages
onert::exporter::TrainInfoBuilder Class Reference

#include <TrainInfoBuilder.h>

Public Member Functions

 TrainInfoBuilder (const std::unique_ptr< ir::train::TrainingInfo > &training_info)
 
uint8_t * get () const
 
uint32_t size () const
 

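Detailed Description

Builds a circle ModelTraining flatbuffer from an ir::train::TrainingInfo. The buffer records:
- the version;
- the optimizer and its learning rate;
- the loss function and loss reduction type;
- the batch size;
- the list of trainable operations.

The constructor finishes the buffer with ::circle::ModelTrainingIdentifier() and verifies it. It throws std::runtime_error if:
- the optimizer, loss, or loss reduction type is not supported; or
- verification fails.

size() returns the size of the serialized buffer (_builder.GetSize()). get() returns a uint8_t*, presumably the start of that buffer; its definition is not listed on this page. Both are used by onert::exporter::CircleExporter::updateMetadata().

Definition at line 27 of file TrainInfoBuilder.h.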

Constructor & Destructor Documentation

◆ TrainInfoBuilder()

onert::exporter::TrainInfoBuilder::TrainInfoBuilder ( const std::unique_ptr< ir::train::TrainingInfo > &  training_info)
inline

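Definition at line 30 of file TrainInfoBuilder.h.

Note: source lines 40, 45, 56, 59, 64, 76 and 79 are missing from the listing below; they were lost when this page was extracted. Judging from the surrounding code:
- Lines 40, 45, 59, 64, 76 and 79 are the `case` labels of the three switch statements.
- Line 56 is presumably the declaration of `lossfn_opt`.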

30 : _builder(1024)
31 {
32 const auto &optimizerInfo = training_info->optimizerInfo();
33 const auto &lossInfo = training_info->lossInfo();
34
35 ::circle::Optimizer optimizer;
36 ::circle::OptimizerOptions optimizer_opt_type;
37 ::flatbuffers::Offset<void> optimizer_opt;
38 switch (optimizerInfo.optim_code)
39 {
41 optimizer = ::circle::Optimizer_SGD;
42 optimizer_opt_type = ::circle::OptimizerOptions_SGDOptions;
43 optimizer_opt = ::circle::CreateSGDOptions(_builder, optimizerInfo.learning_rate).Union();
44 break;
46 optimizer = ::circle::Optimizer_ADAM;
47 optimizer_opt_type = ::circle::OptimizerOptions_AdamOptions;
48 optimizer_opt = ::circle::CreateAdamOptions(_builder, optimizerInfo.learning_rate).Union();
49 break;
50 default:
51 throw std::runtime_error("Not supported optimizer code");
52 }
53
54 ::circle::LossFn lossfn;
55 ::circle::LossFnOptions lossfn_opt_type;
57 switch (lossInfo.loss_code)
58 {
60 lossfn = ::circle::LossFn_MEAN_SQUARED_ERROR;
61 lossfn_opt_type = ::circle::LossFnOptions_MeanSquaredErrorOptions;
62 lossfn_opt = ::circle::CreateMeanSquaredErrorOptions(_builder).Union();
63 break;
65 lossfn = ::circle::LossFn_CATEGORICAL_CROSSENTROPY;
66 lossfn_opt_type = ::circle::LossFnOptions_CategoricalCrossentropyOptions;
67 lossfn_opt = ::circle::CreateCategoricalCrossentropyOptions(_builder).Union();
68 break;
69 default:
70 throw std::runtime_error("Not supported loss code");
71 }
72
73 ::circle::LossReductionType loss_reduction_type;
74 switch (lossInfo.reduction_type)
75 {
77 loss_reduction_type = ::circle::LossReductionType_SumOverBatchSize;
78 break;
80 loss_reduction_type = ::circle::LossReductionType_Sum;
81 break;
82 default:
83 throw std::runtime_error("Not supported loss reduction type");
84 }
85
86 std::vector<int32_t> trainable_ops;
87 for (const auto &op : training_info->getTrainableOps())
88 {
89 trainable_ops.push_back(op.value());
90 }
91
92 const auto end = ::circle::CreateModelTrainingDirect(
93 _builder, training_info->version(), optimizer, optimizer_opt_type, optimizer_opt, lossfn,
94 lossfn_opt_type, lossfn_opt, 0, training_info->batchSize(), loss_reduction_type,
95 &trainable_ops);
96 _builder.Finish(end, ::circle::ModelTrainingIdentifier());
97
98 ::flatbuffers::Verifier v(_builder.GetBufferPointer(), _builder.GetSize());
99 bool verified = ::circle::VerifyModelTrainingBuffer(v);
100 if (not verified)
101 throw std::runtime_error{"TrainingInfo buffer is not accessible"};
102 }
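Symbols referenced in the listing above (Doxygen tooltips):
- void flatbuffers::FlatBufferBuilder::Finish(Offset< T > root, const char *file_identifier=nullptr): finish serializing a buffer by writing the root offset.
- uoffset_t flatbuffers::FlatBufferBuilder::GetSize() const: the current size of the serialized buffer, counting from the end.
- uint8_t * flatbuffers::FlatBufferBuilder::GetBufferPointer() const: get the serialized buffer (after you call Finish()).
- ShapeIterator end(const Shape &s): an unrelated function. Doxygen cross-linked it to the local variable `end` declared at source line 92 and passed to Finish() at line 96; the constructor does not call it.
- Offset< void > flatbuffers::Offset< T >::Union() const: defined at flatbuffers.h:74.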

References onert::ir::train::Adam, onert::ir::train::CategoricalCrossentropy, flatbuffers::FlatBufferBuilder::Finish(), flatbuffers::FlatBufferBuilder::GetBufferPointer(), flatbuffers::FlatBufferBuilder::GetSize(), onert::ir::train::MeanSquaredError, onert::ir::train::SGD, onert::ir::train::Sum, onert::ir::train::SumOverBatchSize, and flatbuffers::Offset< T >::Union().

Member Function Documentation

◆ get()

uint8_t * onert::exporter::TrainInfoBuilder::get ( ) const
inline

◆ size()

uint32_t onert::exporter::TrainInfoBuilder::size ( ) const
inline

Definition at line 105 of file TrainInfoBuilder.h.

105{ return _builder.GetSize(); }

References flatbuffers::FlatBufferBuilder::GetSize().

Referenced by onert::exporter::CircleExporter::updateMetadata().


The documentation for this class was generated from the following file:

TrainInfoBuilder.h