ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::exec::train::TrainableFnSequence Class Reference

#include <TrainableFnSequence.h>

Public Member Functions

void forward (bool training)
 
void backward (uint32_t training_step, bool weight_update_enabled)
 
void append (std::unique_ptr< ITrainableFunction > &&fn)
 
void append (std::unique_ptr< IGradientApplier > &&applier)
 
void iterate (const std::function< void(ITrainableFunction &)> &fn)
 

Data Fields

std::vector< std::unique_ptr< ITrainableFunction > > _functions
 
std::vector< std::unique_ptr< IGradientApplier > > _appliers
 

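Detailed Description

Holds an ordered list of trainable functions (_functions) and gradient appliers (_appliers). forward() runs the functions in the order they were appended. backward() runs them in reverse order and then, if weight updates are enabled, calls applyGradient() on every applier.

Definition at line 33 of file TrainableFnSequence.h.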

Member Function Documentation

◆ append() [1/2]

void onert::exec::train::TrainableFnSequence::append ( std::unique_ptr< IGradientApplier > &&  applier)

Definition at line 54 of file TrainableFnSequence.cc.

55{
56 _appliers.push_back(std::move(applier));
57}
std::vector< std::unique_ptr< IGradientApplier > > _appliers

References _appliers.

◆ append() [2/2]

void onert::exec::train::TrainableFnSequence::append ( std::unique_ptr< ITrainableFunction > &&  fn)

Definition at line 49 of file TrainableFnSequence.cc. (The definition below names the parameter `function`; the declaration above names it `fn`.)

50{
51 _functions.push_back(std::move(function));
52}
std::vector< std::unique_ptr< ITrainableFunction > > _functions

References _functions.

◆ backward()

void onert::exec::train::TrainableFnSequence::backward ( uint32_t  training_step,
bool  weight_update_enabled 
)

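Calls backward() on each function in reverse order of insertion. If weight_update_enabled is true, it then calls applyGradient(training_step) on each applier in insertion order; otherwise no gradients are applied.

Definition at line 34 of file TrainableFnSequence.cc.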

35{
36 for (auto it = _functions.rbegin(); it != _functions.rend(); ++it)
37 {
38 (*it)->backward();
39 }
40 if (weight_update_enabled)
41 {
42 for (const auto &applier : _appliers)
43 {
44 applier->applyGradient(training_step);
45 }
46 }
47}

References _appliers and _functions.

◆ forward()

void onert::exec::train::TrainableFnSequence::forward ( bool  training)

Definition at line 26 of file TrainableFnSequence.cc.

27{
28 for (const auto &function : _functions)
29 {
30 function->forward(training);
31 }
32}

References _functions.

◆ iterate()

void onert::exec::train::TrainableFnSequence::iterate ( const std::function< void(ITrainableFunction &)> &  fn)

Definition at line 59 of file TrainableFnSequence.cc.

60{
61 for (const auto &func : _functions)
62 {
63 fn(*func);
64 }
65}

References _functions.

Field Documentation

◆ _appliers

std::vector<std::unique_ptr<IGradientApplier> > onert::exec::train::TrainableFnSequence::_appliers

Definition at line 46 of file TrainableFnSequence.h.

Referenced by append() and backward().

◆ _functions

std::vector<std::unique_ptr<ITrainableFunction> > onert::exec::train::TrainableFnSequence::_functions

Definition at line 45 of file TrainableFnSequence.h.

Referenced by append(), backward(), forward(), and iterate().


The documentation for this class was generated from the following files: