17#ifndef __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
18#define __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
90 auto _migrant_tensor = _migrant.find(index);
91 if (_migrant_tensor != _migrant.end())
92 return _migrant_tensor->second;
99 if (tensor ==
nullptr)
118 for (
const auto &[index, tensor] : _trainable)
119 fn(index, tensor.get());
124 auto tensor = _trainable.find(index);
125 if (tensor != _trainable.end())
128 return tensor->second.get();
135 auto tensor = _non_const.find(index);
136 if (tensor != _non_const.end())
137 return tensor->second.get();
143 auto tensor = _trainable.find(index);
144 if (tensor != _trainable.end())
145 return tensor->second.get();
152 auto tensor = _back_prop.find(index);
153 if (tensor != _back_prop.end())
154 return tensor->second.get();
160 auto tensor = _gradient.find(index);
161 if (tensor != _gradient.end())
162 return tensor->second.get();
169 if (trainable ==
nullptr)
170 throw std::runtime_error{
171 "Tried to get a trainable tensor but the corresponding tensor does not exist."};
174 if (gradient ==
nullptr)
175 throw std::runtime_error{
176 "Tried to get a gradient tensor but the corresponding tensor does not exist."};
183 assert(tensor !=
nullptr);
185 throw std::runtime_error{
186 "Tried to set a trainable tensor but another tensor already exists."};
188 _migrant[index] = tensor;
194 assert(tensor !=
nullptr);
196 throw std::runtime_error{
197 "Tried to set a trainable tensor but another tensor already exists."};
199 _non_const[index] = std::move(tensor);
204 assert(tensor !=
nullptr);
206 throw std::runtime_error{
207 "Tried to set a trainable tensor but another tensor already exists."};
209 _trainable[index] = std::move(tensor);
214 assert(tensor !=
nullptr);
215 auto itr = _back_prop.find(index);
216 if (itr != _back_prop.end())
217 throw std::runtime_error{
"Tried to set a back propagation tensor but another back "
218 "propagation tensor already exists."};
220 _back_prop[index] = std::move(tensor);
225 assert(tensor !=
nullptr);
226 auto itr = _gradient.find(index);
227 if (itr != _gradient.end())
228 throw std::runtime_error{
229 "Tried to set a gradient tensor but another gradient tensor already exists."};
231 _gradient[index] = std::move(tensor);
A tensor class that is portable to other backends.
virtual ITensor * getBackPropITensor(const ir::OperandIndex &)=0
Returns a pointer to the ITensor for back propagation.
virtual ITensor * getGradientITensor(const ir::OperandIndex &)=0
Returns a pointer to the ITensor for the gradient.
virtual void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &) const =0
Iterates over the ITrainableTensors, calling fn for each one.
A tensor class that can be trained.
ITensor * getGradientITensor(const ir::OperandIndex &index) override
Returns a pointer to the ITensor for the gradient.
IPortableTensor * getPortableTensor(const ir::OperandIndex &index)
Tensor * getNonConstTensor(const ir::OperandIndex &index)
void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
void setNonConstTensor(const ir::OperandIndex &index, std::unique_ptr< Tensor > tensor)
std::tuple< TrainableTensor *, GradientTensor * > TrainingTensors
ITensor * getBackPropITensor(const ir::OperandIndex &index) override
Returns a pointer to the ITensor for back propagation.
void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
ITensor * getNativeITensor(const ir::OperandIndex &index) override
Returns a pointer to the ITensor found among the native tensors.
ITensor * getITensor(const ir::OperandIndex &index) override
Returns a pointer to the ITensor found among the native and migrant tensors.
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & nonconst_tensors()
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & trainable_tensors()
TrainableTensor * getTrainableTensor(const ir::OperandIndex &index)
GradientTensor * getGradientTensor(const ir::OperandIndex &index)
void setTrainableTensor(const ir::OperandIndex &index, std::unique_ptr< TrainableTensor > tensor)
BackPropTensor * getBackPropTensor(const ir::OperandIndex &index)
bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
Sets a migrant tensor, i.e. a tensor that comes from another backend.
TrainingTensors getTrainingTensors(const ir::OperandIndex &index)
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn) const override
Iterates over the ITrainableTensors, calling fn for each one.
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & gradient_tensors()
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & back_prop_tensors()
basic::train::TrainableTensor TrainableTensor
std::unordered_map< OperandIndex, T > OperandIndexMap
virtual ITensor * getITensor(const ir::OperandIndex &)=0
Returns a pointer to the ITensor found among the native and migrant tensors.
virtual ITensor * getNativeITensor(const ir::OperandIndex &)=0
Returns a pointer to the ITensor found among the native tensors.