ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor > Class Template Reference

#include <ITensorRegistry.h>

Collaboration diagram for onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >:

Public Types

using TrainingTensors = std::tuple< TrainableTensor *, GradientTensor * >
 

Public Member Functions

ITensor * getITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor among native and migrant tensors.
 
ITensor * getNativeITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor among native tensors.
 
ITensor * getBackPropITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor for back propagation.
 
ITensor * getGradientITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor for gradient.
 
void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn) const override
 Iterate ITrainableTensors with fn.
 
IPortableTensor * getPortableTensor (const ir::OperandIndex &index)
 
Tensor * getNonConstTensor (const ir::OperandIndex &index)
 
TrainableTensor * getTrainableTensor (const ir::OperandIndex &index)
 
BackPropTensor * getBackPropTensor (const ir::OperandIndex &index)
 
GradientTensor * getGradientTensor (const ir::OperandIndex &index)
 
TrainingTensors getTrainingTensors (const ir::OperandIndex &index)
 
bool setMigrantTensor (const ir::OperandIndex &index, IPortableTensor *tensor) override
 Sets a migrant tensor, i.e. a tensor that comes from another backend.
 
void setNonConstTensor (const ir::OperandIndex &index, std::unique_ptr< Tensor > tensor)
 
void setTrainableTensor (const ir::OperandIndex &index, std::unique_ptr< TrainableTensor > tensor)
 
void setBackPropTensor (const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
 
void setGradientTensor (const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
 
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & trainable_tensors ()
 
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & nonconst_tensors ()
 
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & back_prop_tensors ()
 
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & gradient_tensors ()
 
- Public Member Functions inherited from onert::backend::ITensorRegistry
virtual ~ITensorRegistry ()=default
 Destructs itself.
 

Detailed Description

template<typename Tensor, typename TrainableTensor, typename BackPropTensor, typename GradientTensor>
class onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >

Definition at line 82 of file ITensorRegistry.h.

Member Typedef Documentation

◆ TrainingTensors

Definition at line 85 of file ITensorRegistry.h.

Member Function Documentation

◆ back_prop_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::back_prop_tensors ( )
inline

Definition at line 239 of file ITensorRegistry.h.

239{ return _back_prop; }

◆ getBackPropITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getBackPropITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor for back propagation.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::train::ITensorRegistry.

Definition at line 104 of file ITensorRegistry.h.

105 {
106 return getBackPropTensor(index);
107 }
BackPropTensor * getBackPropTensor(const ir::OperandIndex &index)

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getBackPropTensor().

◆ getBackPropTensor()

Definition at line 150 of file ITensorRegistry.h.

151 {
152 auto tensor = _back_prop.find(index);
153 if (tensor != _back_prop.end())
154 return tensor->second.get();
155 return nullptr;
156 }

Referenced by onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getBackPropITensor().

◆ getGradientITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getGradientITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor for gradient.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::train::ITensorRegistry.

Definition at line 109 of file ITensorRegistry.h.

110 {
111 return getGradientTensor(index);
112 }
GradientTensor * getGradientTensor(const ir::OperandIndex &index)

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getGradientTensor().

◆ getGradientTensor()

◆ getITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor among native and migrant tensors.

A native tensor is a tensor that is managed by this backend. A migrant tensor is a tensor that is imported from another backend.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::ITensorRegistry.

Definition at line 88 of file ITensorRegistry.h.

89 {
90 auto _migrant_tensor = _migrant.find(index);
91 if (_migrant_tensor != _migrant.end())
92 return _migrant_tensor->second;
93 return getNativeITensor(index);
94 }
ITensor * getNativeITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native tensors.

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNativeITensor().

Referenced by onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setMigrantTensor(), onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setNonConstTensor(), and onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setTrainableTensor().

◆ getNativeITensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
ITensor * onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNativeITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor among native tensors.

Unlike getITensor, this function only searches among native tensors.

Note
The returned tensor must not be used beyond the lifetime of the dynamic tensor manager.

Implements onert::backend::ITensorRegistry.

Definition at line 96 of file ITensorRegistry.h.

97 {
98 ITensor *tensor = getTrainableTensor(index);
99 if (tensor == nullptr)
100 tensor = getNonConstTensor(index);
101 return tensor;
102 }
Tensor * getNonConstTensor(const ir::OperandIndex &index)
TrainableTensor * getTrainableTensor(const ir::OperandIndex &index)

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNonConstTensor(), and onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getTrainableTensor().

Referenced by onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ getNonConstTensor()

◆ getPortableTensor()

Definition at line 122 of file ITensorRegistry.h.

123 {
124 auto tensor = _trainable.find(index);
125 if (tensor != _trainable.end())
126 {
127 if (tensor->second)
128 return tensor->second.get();
129 }
130 return getNonConstTensor(index);
131 }

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getNonConstTensor().

◆ getTrainableTensor()

◆ getTrainingTensors()

Definition at line 166 of file ITensorRegistry.h.

167 {
168 auto trainable = getTrainableTensor(index);
169 if (trainable == nullptr)
170 throw std::runtime_error{
171 "Tried to get a trainable tensor but the corresponding tensor does not exist."};
172
173 auto gradient = getGradientTensor(index);
174 if (gradient == nullptr)
175 throw std::runtime_error{
176 "Tried to get a gradient tensor but the corresponding tensor does not exist."};
177
178 return TrainingTensors{std::make_pair(trainable, gradient)};
179 }
std::tuple< TrainableTensor *, GradientTensor * > TrainingTensors

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getGradientTensor(), and onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getTrainableTensor().

◆ gradient_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::gradient_tensors ( )
inline

Definition at line 240 of file ITensorRegistry.h.

241 {
242 return _gradient;
243 }

◆ iterateTrainableTensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::iterateTrainableTensors ( const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> & fn ) const
inlineoverridevirtual

Iterate ITrainableTensors with fn.

Parameters
fn  function to be called with an OperandIndex and a pointer to the corresponding ITrainableTensor

Implements onert::backend::train::ITensorRegistry.

Definition at line 114 of file ITensorRegistry.h.

117 {
118 for (const auto &[index, tensor] : _trainable)
119 fn(index, tensor.get());
120 }
KnobTrait< K >::ValueType get(void)
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

◆ nonconst_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::nonconst_tensors ( )
inline

Definition at line 238 of file ITensorRegistry.h.

238{ return _non_const; }

◆ setBackPropTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setBackPropTensor ( const ir::OperandIndex & index,
std::unique_ptr< BackPropTensor > tensor
)
inline

Definition at line 212 of file ITensorRegistry.h.

213 {
214 assert(tensor != nullptr);
215 auto itr = _back_prop.find(index);
216 if (itr != _back_prop.end())
217 throw std::runtime_error{"Tried to set a back propagation tensor but another back "
218 "propagation tensor already exists."};
219
220 _back_prop[index] = std::move(tensor);
221 }

◆ setGradientTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setGradientTensor ( const ir::OperandIndex & index,
std::unique_ptr< GradientTensor > tensor
)
inline

Definition at line 223 of file ITensorRegistry.h.

224 {
225 assert(tensor != nullptr);
226 auto itr = _gradient.find(index);
227 if (itr != _gradient.end())
228 throw std::runtime_error{
229 "Tried to set a gradient tensor but another gradient tensor already exists."};
230
231 _gradient[index] = std::move(tensor);
232 }

◆ setMigrantTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
bool onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setMigrantTensor ( const ir::OperandIndex & index,
IPortableTensor * tensor
)
inlineoverridevirtual

Sets a migrant tensor, i.e. a tensor that comes from another backend.

Returns
true if supported
false if not supported

Reimplemented from onert::backend::ITensorRegistry.

Definition at line 181 of file ITensorRegistry.h.

182 {
183 assert(tensor != nullptr);
184 if (getITensor(index) != nullptr)
185 throw std::runtime_error{
186 "Tried to set a trainable tensor but another tensor already exists."};
187
188 _migrant[index] = tensor;
189 return true;
190 }
ITensor * getITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native and migrant tensors.

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ setNonConstTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setNonConstTensor ( const ir::OperandIndex & index,
std::unique_ptr< Tensor > tensor
)
inline

Definition at line 192 of file ITensorRegistry.h.

193 {
194 assert(tensor != nullptr);
195 if (getITensor(index) != nullptr)
196 throw std::runtime_error{
197 "Tried to set a trainable tensor but another tensor already exists."};
198
199 _non_const[index] = std::move(tensor);
200 }

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ setTrainableTensor()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
void onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::setTrainableTensor ( const ir::OperandIndex & index,
std::unique_ptr< TrainableTensor > tensor
)
inline

Definition at line 202 of file ITensorRegistry.h.

203 {
204 assert(tensor != nullptr);
205 if (getITensor(index) != nullptr)
206 throw std::runtime_error{
207 "Tried to set a trainable tensor but another tensor already exists."};
208
209 _trainable[index] = std::move(tensor);
210 }

References onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::getITensor().

◆ trainable_tensors()

template<typename Tensor , typename TrainableTensor , typename BackPropTensor , typename GradientTensor >
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & onert::backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor >::trainable_tensors ( )
inline

Definition at line 234 of file ITensorRegistry.h.

235 {
236 return _trainable;
237 }

The documentation for this class was generated from the following file: