ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::exec::train::TrainableFnSequence Class Reference

#include <TrainableFnSequence.h>

Public Member Functions

void forward (bool training)
 
void backward (uint32_t training_step, bool weight_update_enabled)
 
void append (std::unique_ptr< ITrainableFunction > &&fn)
 
void append (std::unique_ptr< IGradientApplier > &&applier)
 
void iterate (const std::function< void(ITrainableFunction &)> &fn)
 

Data Fields

std::vector< std::unique_ptr< ITrainableFunction > > _functions
 
std::vector< std::unique_ptr< IGradientApplier > > _appliers
 

Detailed Description

Definition at line 29 of file TrainableFnSequence.h.

Member Function Documentation

◆ append() [1/2]

void onert::exec::train::TrainableFnSequence::append ( std::unique_ptr< IGradientApplier > &&  applier)

Definition at line 50 of file TrainableFnSequence.cc.

51{
52 _appliers.push_back(std::move(applier));
53}
std::vector< std::unique_ptr< IGradientApplier > > _appliers

References _appliers.

◆ append() [2/2]

void onert::exec::train::TrainableFnSequence::append ( std::unique_ptr< ITrainableFunction > &&  fn)

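Definition at line 45 of file TrainableFnSequence.cc. (Note: the declaration shown above names the parameter `fn`, while the definition in TrainableFnSequence.cc names it `function`, which is why the listing below refers to `function`.)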

46{
47 _functions.push_back(std::move(function));
48}
std::vector< std::unique_ptr< ITrainableFunction > > _functions

References _functions.

◆ backward()

void onert::exec::train::TrainableFnSequence::backward (uint32_t training_step, bool weight_update_enabled)

Definition at line 30 of file TrainableFnSequence.cc.

31{
32 for (auto it = _functions.rbegin(); it != _functions.rend(); ++it)
33 {
34 (*it)->backward();
35 }
36 if (weight_update_enabled)
37 {
38 for (const auto &applier : _appliers)
39 {
40 applier->applyGradient(training_step);
41 }
42 }
43}

References _appliers, and _functions.

◆ forward()

void onert::exec::train::TrainableFnSequence::forward ( bool  training)

Definition at line 22 of file TrainableFnSequence.cc.

23{
24 for (const auto &function : _functions)
25 {
26 function->forward(training);
27 }
28}

References _functions.

◆ iterate()

void onert::exec::train::TrainableFnSequence::iterate ( const std::function< void(ITrainableFunction &)> &  fn)

Definition at line 55 of file TrainableFnSequence.cc.

56{
57 for (const auto &func : _functions)
58 {
59 fn(*func);
60 }
61}

References _functions.

Field Documentation

◆ _appliers

std::vector<std::unique_ptr<IGradientApplier> > onert::exec::train::TrainableFnSequence::_appliers

Definition at line 42 of file TrainableFnSequence.h.

Referenced by append(), and backward().

◆ _functions

std::vector<std::unique_ptr<ITrainableFunction> > onert::exec::train::TrainableFnSequence::_functions

Definition at line 41 of file TrainableFnSequence.h.

Referenced by append(), backward(), forward(), and iterate().


The documentation for this class was generated from the following files: