ONE - On-device Neural Engine
onert::backend::train::ops::GradientApplier Class Reference

#include <GradientApplier.h>


Public Member Functions

GradientApplier ()
~GradientApplier ()=default
void configure (const exec::train::optimizer::Optimizer *optimizer, const IPortableTensor *gradient, ITrainableTensor *trainable)
    Store the optimizer and the gradient/trainable tensor pair used by applyGradient().
void applyGradient (uint32_t training_step) override
    Apply gradients to a trainable tensor.

- Public Member Functions inherited from onert::exec::train::IGradientApplier
virtual ~IGradientApplier ()=default

Detailed Description

Definition at line 33 of file GradientApplier.h.

Constructor & Destructor Documentation

◆ GradientApplier()

onert::backend::train::ops::GradientApplier::GradientApplier ( )

Definition at line 30 of file GradientApplier.cc.

GradientApplier::GradientApplier()
  : _optimizer{nullptr}, _gradient_tensor{}, _trainable_tensor{}
{
  // DO NOTHING: all members are initialized to null above
}

◆ ~GradientApplier()

onert::backend::train::ops::GradientApplier::~GradientApplier ( ) [default]

Member Function Documentation

◆ applyGradient()

void onert::backend::train::ops::GradientApplier::applyGradient ( uint32_t training_step ) [override], [virtual]

Apply gradients to a trainable tensor.

Parameters
    training_step    The number of iterations of the training process.

Implements onert::exec::train::IGradientApplier.

Definition at line 43 of file GradientApplier.cc.

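void GradientApplier::applyGradient(uint32_t training_step)
{
  // Bundle references to the gradient, the trainable tensor and the step,
  // then let the configured optimizer perform the update.
  _optimizer->applyGradient(
    std::forward_as_tuple(*_gradient_tensor, *_trainable_tensor, training_step));
}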

References onert::exec::train::optimizer::Optimizer::applyGradient(), declared as virtual void applyGradient(const UpdateFactors &factors) const =0 ("Apply gradient to a trainable tensor.").
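
The call above uses std::forward_as_tuple to pack references to the gradient tensor, the trainable tensor and the training step into a single tuple, so the optimizer updates the trainable tensor in place instead of a copy. The following is a minimal, self-contained sketch of that pattern, not the onert implementation: Tensor, the UpdateFactors layout (gradient, trainable tensor, step), the helper applyWithOptimizer and the SGD optimizer are hypothetical stand-ins for IPortableTensor, ITrainableTensor and the real optimizer classes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

// Hypothetical stand-in for IPortableTensor / ITrainableTensor: a flat float buffer.
struct Tensor
{
  std::vector<float> data;
};

// Assumed layout of UpdateFactors: (gradient, trainable tensor, training step).
using UpdateFactors = std::tuple<const Tensor &, Tensor &, uint32_t>;

// Mirrors the `virtual void applyGradient(const UpdateFactors &) const = 0` interface.
class Optimizer
{
public:
  virtual ~Optimizer() = default;
  virtual void applyGradient(const UpdateFactors &factors) const = 0;
};

// Hypothetical plain SGD: weights -= lr * gradient (sizes assumed to match).
class SGD : public Optimizer
{
public:
  explicit SGD(float lr) : _lr{lr} {}

  void applyGradient(const UpdateFactors &factors) const override
  {
    const Tensor &grad = std::get<0>(factors);
    Tensor &weights = std::get<1>(factors); // reference element: still mutable
    (void)std::get<2>(factors);             // plain SGD does not use the step
    for (std::size_t i = 0; i < weights.data.size(); ++i)
      weights.data[i] -= _lr * grad.data[i];
  }

private:
  float _lr;
};

// Hypothetical helper mirroring GradientApplier::applyGradient: no tensor is copied,
// the tuple of references converts to a temporary UpdateFactors bound to the const&.
void applyWithOptimizer(const Optimizer &opt, const Tensor &grad, Tensor &weights,
                        uint32_t step)
{
  opt.applyGradient(std::forward_as_tuple(grad, weights, step));
}
```

Because the tuple elements are references, std::get<1> on a const UpdateFactors & still yields a non-const Tensor &, which is what allows a const applyGradient() to modify the weights.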

◆ configure()

void onert::backend::train::ops::GradientApplier::configure ( const exec::train::optimizer::Optimizer *optimizer,
const IPortableTensor *gradient,
ITrainableTensor *trainable
)

Definition at line 35 of file GradientApplier.cc.

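void GradientApplier::configure(const exec::train::optimizer::Optimizer *optimizer,
                                const IPortableTensor *gradient, ITrainableTensor *trainable)
{
  // Store non-owning pointers; they are dereferenced later in applyGradient().
  _optimizer = optimizer;
  _gradient_tensor = gradient;
  _trainable_tensor = trainable;
}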
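
configure() only stores raw, non-owning pointers, and the constructor leaves all three of them null, so configure() has to run before the first applyGradient() call, and the optimizer and both tensors must outlive the applier. The sketch below illustrates that lifecycle with a simplified GradientApplier; Tensor, UpdateFactors, Optimizer and RecordingOptimizer are hypothetical stand-ins for the onert types, and the null-pointer assert in applyGradient() is an addition that does not appear in the original code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

// Hypothetical stand-ins for the onert tensor and optimizer types.
struct Tensor
{
  std::vector<float> data;
};
using UpdateFactors = std::tuple<const Tensor &, Tensor &, uint32_t>;

struct Optimizer
{
  virtual ~Optimizer() = default;
  virtual void applyGradient(const UpdateFactors &factors) const = 0;
};

// Subtracts the gradient (learning rate 1) and records the last step it received.
struct RecordingOptimizer : Optimizer
{
  mutable uint32_t last_step = 0;

  void applyGradient(const UpdateFactors &f) const override
  {
    last_step = std::get<2>(f);
    const Tensor &grad = std::get<0>(f);
    Tensor &weights = std::get<1>(f);
    for (std::size_t i = 0; i < weights.data.size(); ++i)
      weights.data[i] -= grad.data[i];
  }
};

// Simplified GradientApplier: configure() wires the pointers, applyGradient() uses them.
class GradientApplier
{
public:
  void configure(const Optimizer *optimizer, const Tensor *gradient, Tensor *trainable)
  {
    _optimizer = optimizer;
    _gradient_tensor = gradient;
    _trainable_tensor = trainable;
  }

  void applyGradient(uint32_t training_step)
  {
    // Not in the original: catch use before configure() in debug builds.
    assert(_optimizer && _gradient_tensor && _trainable_tensor);
    _optimizer->applyGradient(
      std::forward_as_tuple(*_gradient_tensor, *_trainable_tensor, training_step));
  }

private:
  const Optimizer *_optimizer{nullptr};
  const Tensor *_gradient_tensor{nullptr};
  Tensor *_trainable_tensor{nullptr};
};
```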

The documentation for this class was generated from the following files:
GradientApplier.h
GradientApplier.cc