30 : _builder(1024)
31 {
32 const auto &optimizerInfo = training_info->optimizerInfo();
33 const auto &lossInfo = training_info->lossInfo();
34
36 ::circle::OptimizerOptions optimizer_opt_type;
38 switch (optimizerInfo.optim_code)
39 {
42 optimizer_opt_type = ::circle::OptimizerOptions_SGDOptions;
43 optimizer_opt = ::circle::CreateSGDOptions(_builder, optimizerInfo.learning_rate).
Union();
44 break;
47 optimizer_opt_type = ::circle::OptimizerOptions_AdamOptions;
48 optimizer_opt = ::circle::CreateAdamOptions(_builder, optimizerInfo.learning_rate).
Union();
49 break;
50 default:
51 throw std::runtime_error("Not supported optimizer code");
52 }
53
54 ::circle::LossFn lossfn;
55 ::circle::LossFnOptions lossfn_opt_type;
57 switch (lossInfo.loss_code)
58 {
60 lossfn = ::circle::LossFn_MEAN_SQUARED_ERROR;
61 lossfn_opt_type = ::circle::LossFnOptions_MeanSquaredErrorOptions;
62 lossfn_opt = ::circle::CreateMeanSquaredErrorOptions(_builder).
Union();
63 break;
65 lossfn = ::circle::LossFn_CATEGORICAL_CROSSENTROPY;
66 lossfn_opt_type = ::circle::LossFnOptions_CategoricalCrossentropyOptions;
67 lossfn_opt = ::circle::CreateCategoricalCrossentropyOptions(_builder).
Union();
68 break;
69 default:
70 throw std::runtime_error("Not supported loss code");
71 }
72
73 ::circle::LossReductionType loss_reduction_type;
74 switch (lossInfo.reduction_type)
75 {
77 loss_reduction_type = ::circle::LossReductionType_SumOverBatchSize;
78 break;
80 loss_reduction_type = ::circle::LossReductionType_Sum;
81 break;
82 default:
83 throw std::runtime_error("Not supported loss reduction type");
84 }
85
86 std::vector<int32_t> trainable_ops;
87 for (const auto &op : training_info->getTrainableOps())
88 {
89 trainable_ops.push_back(op.value());
90 }
91
92 const auto end = ::circle::CreateModelTrainingDirect(
93 _builder, training_info->version(),
optimizer, optimizer_opt_type, optimizer_opt, lossfn,
94 lossfn_opt_type, lossfn_opt, 0, training_info->batchSize(), loss_reduction_type,
95 &trainable_ops);
96 _builder.
Finish(end, ::circle::ModelTrainingIdentifier());
97
99 bool verified = ::circle::VerifyModelTrainingBuffer(v);
100 if (not verified)
101 throw std::runtime_error{"TrainingInfo buffer is not accessible"};
102 }
void Finish(Offset< T > root, const char *file_identifier=nullptr)
Finish serializing a buffer by writing the root offset.
uoffset_t GetSize() const
The current size of the serialized buffer, counting from the end.
uint8_t * GetBufferPointer() const
Get the serialized buffer (after you call Finish()).
ShapeIterator end(const Shape &s)
@ CategoricalCrossentropy
Offset< void > Union() const