ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor > Class Template Reference

#include <ITensorRegistry.h>

Collaboration diagram for onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >:

Public Types

using TrainingTensors = std::tuple< TrainableTensor *, GradientTensor * >
 

Public Member Functions

ITensor * getITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor found among native and migrant tensors.
 
ITensor * getNativeITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor found among native tensors only.
 
ITensor * getBackPropITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor used for back propagation.
 
ITensor * getGradientITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor used for the gradient.
 
void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn) const override
 Iterate ITrainableTensors with fn.
 
IPortableTensor * getPortableTensor (const ir::OperandIndex &index)
 
Tensor * getNonConstTensor (const ir::OperandIndex &index)
 
TrainableTensor * getTrainableTensor (const ir::OperandIndex &index)
 
BackPropTensor * getBackPropTensor (const ir::OperandIndex &index)
 
GradientTensor * getGradientTensor (const ir::OperandIndex &index)
 
TrainingTensors getTrainingTensors (const ir::OperandIndex &index)
 
bool setMigrantTensor (const ir::OperandIndex &index, IPortableTensor *tensor) override
 Sets a migrant tensor, i.e. a tensor that comes from another backend.
 
void setNonConstTensor (const ir::OperandIndex &index, std::unique_ptr< Tensor > tensor)
 
void setTrainableTensor (const ir::OperandIndex &index, std::unique_ptr< TrainableTensor > tensor)
 
void setBackPropTensor (const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
 
void setGradientTensor (const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
 
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & trainable_tensors ()
 
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & nonconst_tensors ()
 
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & back_prop_tensors ()
 
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & gradient_tensors ()
 
- Public Member Functions inherited from onert::backend::ITensorRegistry
virtual ~ITensorRegistry ()=default
 Destroys itself.
 

Detailed Description

template<typename Tensor, typename TrainableTensor, typename BackPropTensor, typename GradientTensor>
class onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >

Definition at line 72 of file ITensorRegistry.h.

Member Typedef Documentation

◆ TrainingTensors

Definition at line 75 of file ITensorRegistry.h.

Member Function Documentation

◆ back_prop_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::back_prop_tensors ( )
inline

Definition at line 229 of file ITensorRegistry.h.

229{ return _back_prop; }

◆ getBackPropITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getBackPropITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns a pointer to the ITensor used for back propagation.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::train::ITensorRegistry.

Definition at line 94 of file ITensorRegistry.h.

95 {
96 return getBackPropTensor(index);
97 }
BackPropTensor * getBackPropTensor(const ir::OperandIndex &index)

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getBackPropTensor().

◆ getBackPropTensor()

Definition at line 140 of file ITensorRegistry.h.

141 {
142 auto tensor = _back_prop.find(index);
143 if (tensor != _back_prop.end())
144 return tensor->second.get();
145 return nullptr;
146 }

Referenced by onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getBackPropITensor().

◆ getGradientITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getGradientITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor for gradient.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::train::ITensorRegistry.

Definition at line 99 of file ITensorRegistry.h.

100 {
101 return getGradientTensor(index);
102 }
GradientTensor * getGradientTensor(const ir::OperandIndex &index)

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getGradientTensor().

◆ getGradientTensor()

◆ getITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor among native and migrant tensors.

A native tensor is a tensor that is managed by this backend. A migrant tensor is a tensor that is imported from another backend.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::ITensorRegistry.

Definition at line 78 of file ITensorRegistry.h.

79 {
80 auto _migrant_tensor = _migrant.find(index);
81 if (_migrant_tensor != _migrant.end())
82 return _migrant_tensor->second;
83 return getNativeITensor(index);
84 }
ITensor * getNativeITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native tensors.

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNativeITensor().

Referenced by onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setMigrantTensor(), onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setNonConstTensor(), and onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setTrainableTensor().

◆ getNativeITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNativeITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor among native tensors.

Unlike getITensor, this function searches only the native tensors.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::ITensorRegistry.

Definition at line 86 of file ITensorRegistry.h.

87 {
88 ITensor *tensor = getTrainableTensor(index);
89 if (tensor == nullptr)
90 tensor = getNonConstTensor(index);
91 return tensor;
92 }
Tensor * getNonConstTensor(const ir::OperandIndex &index)
TrainableTensor * getTrainableTensor(const ir::OperandIndex &index)

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNonConstTensor(), and onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getTrainableTensor().

Referenced by onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ getNonConstTensor()

◆ getPortableTensor()

Definition at line 112 of file ITensorRegistry.h.

113 {
114 auto tensor = _trainable.find(index);
115 if (tensor != _trainable.end())
116 {
117 if (tensor->second)
118 return tensor->second.get();
119 }
120 return getNonConstTensor(index);
121 }

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNonConstTensor().

◆ getTrainableTensor()

◆ getTrainingTensors()

Definition at line 156 of file ITensorRegistry.h.

157 {
158 auto trainable = getTrainableTensor(index);
159 if (trainable == nullptr)
160 throw std::runtime_error{
161 "Tried to get a trainable tensor but the corresponding tensor does not exist."};
162
163 auto gradient = getGradientTensor(index);
164 if (gradient == nullptr)
165 throw std::runtime_error{
166 "Tried to get a gradient tensor but the corresponding tensor does not exist."};
167
168 return TrainingTensors{std::make_pair(trainable, gradient)};
169 }
std::tuple< TrainableTensor *, GradientTensor * > TrainingTensors

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getGradientTensor(), and onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getTrainableTensor().

◆ gradient_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::gradient_tensors ( )
inline

Definition at line 230 of file ITensorRegistry.h.

231 {
232 return _gradient;
233 }

◆ iterateTrainableTensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::iterateTrainableTensors ( const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> & fn ) const
inlineoverridevirtual

Iterate ITrainableTensors with fn.

Parameters
fn  function to be called with an OperandIndex and a pointer to an ITrainableTensor

Implements onert::backend::train::ITensorRegistry.

Definition at line 104 of file ITensorRegistry.h.

107 {
108 for (const auto &[index, tensor] : _trainable)
109 fn(index, tensor.get());
110 }
KnobTrait< K >::ValueType get(void)
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

◆ nonconst_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::nonconst_tensors ( )
inline

Definition at line 228 of file ITensorRegistry.h.

228{ return _non_const; }

◆ setBackPropTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setBackPropTensor ( const ir::OperandIndex & index,
std::unique_ptr< BackPropTensor > tensor 
)
inline

Definition at line 202 of file ITensorRegistry.h.

203 {
204 assert(tensor != nullptr);
205 auto itr = _back_prop.find(index);
206 if (itr != _back_prop.end())
207 throw std::runtime_error{"Tried to set a back propagation tensor but another back "
208 "propagation tensor already exists."};
209
210 _back_prop[index] = std::move(tensor);
211 }

◆ setGradientTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setGradientTensor ( const ir::OperandIndex & index,
std::unique_ptr< GradientTensor > tensor 
)
inline

Definition at line 213 of file ITensorRegistry.h.

214 {
215 assert(tensor != nullptr);
216 auto itr = _gradient.find(index);
217 if (itr != _gradient.end())
218 throw std::runtime_error{
219 "Tried to set a gradient tensor but another gradient tensor already exists."};
220
221 _gradient[index] = std::move(tensor);
222 }

◆ setMigrantTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
bool onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setMigrantTensor ( const ir::OperandIndex & index,
IPortableTensor * tensor 
)
inlineoverridevirtual

Set the Migrant Tensor which are from other backends.

Returns
true if supported
false if not supported

Reimplemented from onert::backend::ITensorRegistry.

Definition at line 171 of file ITensorRegistry.h.

172 {
173 assert(tensor != nullptr);
174 if (getITensor(index) != nullptr)
175 throw std::runtime_error{
176 "Tried to set a trainable tensor but another tensor already exists."};
177
178 _migrant[index] = tensor;
179 return true;
180 }
ITensor * getITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native and migrant tensors.

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ setNonConstTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setNonConstTensor ( const ir::OperandIndex & index,
std::unique_ptr< Tensor > tensor 
)
inline

Definition at line 182 of file ITensorRegistry.h.

183 {
184 assert(tensor != nullptr);
185 if (getITensor(index) != nullptr)
186 throw std::runtime_error{
187 "Tried to set a trainable tensor but another tensor already exists."};
188
189 _non_const[index] = std::move(tensor);
190 }

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ setTrainableTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setTrainableTensor ( const ir::OperandIndex & index,
std::unique_ptr< TrainableTensor > tensor 
)
inline

Definition at line 192 of file ITensorRegistry.h.

193 {
194 assert(tensor != nullptr);
195 if (getITensor(index) != nullptr)
196 throw std::runtime_error{
197 "Tried to set a trainable tensor but another tensor already exists."};
198
199 _trainable[index] = std::move(tensor);
200 }

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ trainable_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::trainable_tensors ( )
inline

Definition at line 224 of file ITensorRegistry.h.

225 {
226 return _trainable;
227 }

The documentation for this class was generated from the following file: