ONE - On-device Neural Engine
onert::backend::train::ops::GradientApplier Class Reference

#include <GradientApplier.h>


Public Member Functions

GradientApplier ()
~GradientApplier ()=default
void configure (const exec::train::optimizer::Optimizer *optimizer, const IPortableTensor *gradient, ITrainableTensor *trainable)
void applyGradient (uint32_t training_step) override
    Apply gradients to a trainable tensor.

- Public Member Functions inherited from onert::exec::train::IGradientApplier
virtual ~IGradientApplier ()=default
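The summary above shows GradientApplier overriding the single pure-virtual hook inherited from IGradientApplier. The sketch below is a simplified, self-contained model of that pattern, not onert code: `IGradientApplier` is redeclared locally, and `RecordingApplier` and `runStep` are hypothetical names used only to illustrate how a caller might drive several appliers through the base-class interface with the same step number.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Local stand-in for onert::exec::train::IGradientApplier (simplified).
class IGradientApplier
{
public:
  virtual ~IGradientApplier() = default;
  virtual void applyGradient(uint32_t training_step) = 0;
};

// Hypothetical applier that only records the steps it was asked to apply.
class RecordingApplier : public IGradientApplier
{
public:
  void applyGradient(uint32_t training_step) override { steps.push_back(training_step); }
  std::vector<uint32_t> steps;
};

// Hypothetical driver: dispatch one training step to every registered applier
// through the base-class interface.
void runStep(const std::vector<std::unique_ptr<IGradientApplier>> &appliers, uint32_t step)
{
  for (const auto &applier : appliers)
  {
    assert(applier != nullptr);
    applier->applyGradient(step);
  }
}
```

Because the caller only sees IGradientApplier, it does not need to know which optimizer or tensors each concrete applier was configured with.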

Detailed Description

Definition at line 27 of file GradientApplier.h.

Constructor & Destructor Documentation

◆ GradientApplier()

onert::backend::train::ops::GradientApplier::GradientApplier ( )

Definition at line 24 of file GradientApplier.cc.

GradientApplier::GradientApplier()
  : _optimizer{nullptr}, _gradient_tensor{}, _trainable_tensor{}
{
  // DO NOTHING
}

◆ ~GradientApplier()

onert::backend::train::ops::GradientApplier::~GradientApplier ( )
default

Member Function Documentation

◆ applyGradient()

void onert::backend::train::ops::GradientApplier::applyGradient (uint32_t training_step)
override virtual

Apply gradients to a trainable tensor.

Parameters
training_step  The current training step, i.e. the number of iterations of the training process.

Implements onert::exec::train::IGradientApplier.

Definition at line 37 of file GradientApplier.cc.

{
  _optimizer->applyGradient(
    std::forward_as_tuple(*_gradient_tensor, *_trainable_tensor, training_step));
}
References onert::exec::train::optimizer::Optimizer::applyGradient(), declared as

virtual void applyGradient(const UpdateFactors &factors) const = 0

which applies a gradient to a trainable tensor.
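applyGradient does no arithmetic itself: it bundles references to the gradient tensor, the trainable tensor and the step into a tuple with std::forward_as_tuple and hands that tuple to the optimizer. The sketch below models this under stated assumptions: `UpdateFactors` is assumed to be a tuple of (const gradient reference, mutable trainable reference, step), tensors are reduced to `std::vector<float>`, and `SGD` and `applyGradientSketch` are hypothetical names for a plain gradient-descent optimizer and a free-function version of the documented body, not onert's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

// Assumed shape of UpdateFactors; tensors simplified to std::vector<float>.
using Tensor = std::vector<float>;
using UpdateFactors = std::tuple<const Tensor &, Tensor &, std::size_t>;

class Optimizer
{
public:
  virtual ~Optimizer() = default;
  virtual void applyGradient(const UpdateFactors &factors) const = 0;
};

// Hypothetical plain SGD: w <- w - lr * g (the step is unused here).
class SGD : public Optimizer
{
public:
  explicit SGD(float lr) : _lr{lr} {}
  void applyGradient(const UpdateFactors &factors) const override
  {
    const Tensor &grad = std::get<0>(factors);
    Tensor &weights = std::get<1>(factors); // reference element stays mutable
    assert(grad.size() == weights.size());
    for (std::size_t i = 0; i < weights.size(); ++i)
      weights[i] -= _lr * grad[i];
  }

private:
  float _lr;
};

// Mirrors the documented body: pack references plus the step, then delegate.
// The tuple of references converts to a temporary UpdateFactors on the call.
void applyGradientSketch(const Optimizer &optimizer, const Tensor &gradient, Tensor &trainable,
                         uint32_t training_step)
{
  optimizer.applyGradient(std::forward_as_tuple(gradient, trainable, training_step));
}
```

Passing references rather than copies means the optimizer updates the trainable tensor in place, which is why the optimizer method itself can stay const.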

◆ configure()

void onert::backend::train::ops::GradientApplier::configure (const exec::train::optimizer::Optimizer *optimizer,
const IPortableTensor *gradient,
ITrainableTensor *trainable)

Definition at line 29 of file GradientApplier.cc.

{
  _optimizer = optimizer;
  _gradient_tensor = gradient;
  _trainable_tensor = trainable;
}
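configure only stores three non-owning pointers, which the constructor leaves null, so the object is usable only after configure() and only while the optimizer and both tensors stay alive. The sketch below models that lifecycle (construct, configure once, apply once per step) with hypothetical stand-ins: `GradientApplierModel` and `StepRecordingOptimizer` are illustrative names, tensors are `std::vector<float>`, and the null-pointer assert is an addition that the documented applyGradient does not perform.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

using Tensor = std::vector<float>;
using UpdateFactors = std::tuple<const Tensor &, Tensor &, std::size_t>; // assumed shape

// Hypothetical optimizer: subtracts the gradient and records the last step seen.
class StepRecordingOptimizer
{
public:
  void applyGradient(const UpdateFactors &factors) const
  {
    const Tensor &grad = std::get<0>(factors);
    Tensor &weights = std::get<1>(factors);
    last_step = std::get<2>(factors);
    for (std::size_t i = 0; i < weights.size(); ++i)
      weights[i] -= grad[i];
  }
  mutable std::size_t last_step = 0; // mutable so the const method can record it
};

// Simplified model of GradientApplier's state: non-owning pointers, null until configure().
class GradientApplierModel
{
public:
  void configure(const StepRecordingOptimizer *optimizer, const Tensor *gradient, Tensor *trainable)
  {
    _optimizer = optimizer;
    _gradient_tensor = gradient;
    _trainable_tensor = trainable;
  }

  void applyGradient(uint32_t training_step)
  {
    // Added guard (not in the documented code): catch use before configure().
    assert(_optimizer && _gradient_tensor && _trainable_tensor);
    _optimizer->applyGradient(
      std::forward_as_tuple(*_gradient_tensor, *_trainable_tensor, training_step));
  }

private:
  const StepRecordingOptimizer *_optimizer{nullptr};
  const Tensor *_gradient_tensor{nullptr};
  Tensor *_trainable_tensor{nullptr};
};
```

In this model, configuring once and then calling applyGradient(1), applyGradient(2), applyGradient(3) updates the same trainable tensor three times and leaves the optimizer's recorded step at 3.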
