ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::train::optimizer::SGD Class Reference

SGD optimizer. More...

#include <SGD.h>

Collaboration diagram for onert::backend::train::optimizer::SGD:

Data Structures

struct  Property
 

Public Types

using UpdateFactors = exec::train::optimizer::UpdateFactors
 

Public Member Functions

 SGD ()
 
 SGD (const Property &props)
 
 SGD (double lr)
 
 SGD (const Property &props, double lr)
 
std::string name () const override
 Get the name of the optimizer.
 
double getLearningRate (uint32_t iteration=0) const override
 Get the learning rate.
 
virtual uint32_t getVarCount () const override
 Get the number of optimizer variables.
 
void applyGradient (const UpdateFactors &factors) const override
 Apply gradient to a trainable tensor.
 
- Public Member Functions inherited from onert::exec::train::optimizer::Optimizer
virtual ~Optimizer ()=default
 

Detailed Description

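Stochastic gradient descent (SGD) optimizer. If no learning rate is passed to the constructor, it defaults to 0.01. The current implementation returns this fixed learning rate regardless of the training step.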

Definition at line 35 of file SGD.h.

Member Typedef Documentation

◆ UpdateFactors

using onert::backend::train::optimizer::SGD::UpdateFactors = exec::train::optimizer::UpdateFactors

Constructor & Destructor Documentation

◆ SGD() [1/4]

onert::backend::train::optimizer::SGD::SGD ( )
inline explicit

Definition at line 48 of file SGD.h.

48: _props{}, _learning_rate{0.01} {}

◆ SGD() [2/4]

onert::backend::train::optimizer::SGD::SGD ( const Property & props)
inline explicit

Definition at line 49 of file SGD.h.

49: _props{props}, _learning_rate{0.01} {}

◆ SGD() [3/4]

onert::backend::train::optimizer::SGD::SGD ( double  lr)
inline explicit

Definition at line 50 of file SGD.h.

50: _props{}, _learning_rate{lr} {}

◆ SGD() [4/4]

onert::backend::train::optimizer::SGD::SGD ( const Property & props,
double  lr 
)
inline explicit

Definition at line 51 of file SGD.h.

51: _props{props}, _learning_rate{lr} {}

Member Function Documentation

◆ applyGradient()

void onert::backend::train::optimizer::SGD::applyGradient ( const UpdateFactors & factors) const
override virtual

Apply gradient to a trainable tensor.

Parameters
factors  The UpdateFactors used to apply the gradient to a trainable tensor

Implements onert::exec::train::optimizer::Optimizer.

Definition at line 38 of file SGD.cc.

39{
40 auto [grad_tensor, trainable_tensor, training_step] = factors;
41 assert(trainable_tensor.data_type() == grad_tensor.data_type());
42
43 if (trainable_tensor.getShape() != grad_tensor.getShape())
44 {
45 throw std::runtime_error("SGD: Invalid gradient tensor");
46 }
47
48 const auto lr = getLearningRate(training_step);
49 switch (grad_tensor.data_type())
50 {
51 case ir::DataType::FLOAT32:
52 nnfw::cker::train::GradientDescent(
53 ops::getShape(&trainable_tensor), ops::getBuffer<float>(&trainable_tensor),
54 ops::getShape(&grad_tensor), ops::getBuffer<float>(&grad_tensor), lr);
55 break;
56 default:
57 throw std::runtime_error("SGD: Not supported data type");
58 }
59}
double getLearningRate(uint32_t iteration=0) const override
Get the Learning Rate.
Definition SGD.cc:32
void GradientDescent(const Shape &output_shape, float *output_data, const Shape &grad_shape, const float *grad_data, float learning_rate)
Definition SGD.h:33

References getLearningRate() and nnfw::cker::train::GradientDescent().

◆ getLearningRate()

double onert::backend::train::optimizer::SGD::getLearningRate ( uint32_t  iteration = 0) const
override virtual

Get the learning rate.

Parameters
iteration  The number of training steps (currently unused; the fixed learning rate is returned)
Returns
Learning rate

Implements onert::exec::train::optimizer::Optimizer.

Definition at line 32 of file SGD.cc.

33{
34 // TODO Use iteration, momentum, and nesterov
35 return _learning_rate;
36}

Referenced by applyGradient().

◆ getVarCount()

virtual uint32_t onert::backend::train::optimizer::SGD::getVarCount ( ) const
inline override virtual

Get the number of optimizer variables.

Returns
The number of optimizer variables

Implements onert::exec::train::optimizer::Optimizer.

Definition at line 74 of file SGD.h.

74{ return 0; };

◆ name()

std::string onert::backend::train::optimizer::SGD::name ( ) const
inline override virtual

Get the name of the optimizer.

Returns
The name of the optimizer

Reimplemented from onert::exec::train::optimizer::Optimizer.

Definition at line 59 of file SGD.h.

59{ return std::string{"SGD"}; }

The documentation for this class was generated from the following files: