17#ifndef __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
18#define __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
80 auto _migrant_tensor = _migrant.find(index);
81 if (_migrant_tensor != _migrant.end())
82 return _migrant_tensor->second;
89 if (tensor ==
nullptr)
108 for (
const auto &[index, tensor] : _trainable)
109 fn(index, tensor.get());
114 auto tensor = _trainable.find(index);
115 if (tensor != _trainable.end())
118 return tensor->second.get();
125 auto tensor = _non_const.find(index);
126 if (tensor != _non_const.end())
127 return tensor->second.get();
133 auto tensor = _trainable.find(index);
134 if (tensor != _trainable.end())
135 return tensor->second.get();
142 auto tensor = _back_prop.find(index);
143 if (tensor != _back_prop.end())
144 return tensor->second.get();
150 auto tensor = _gradient.find(index);
151 if (tensor != _gradient.end())
152 return tensor->second.get();
159 if (trainable ==
nullptr)
160 throw std::runtime_error{
161 "Tried to get a trainable tensor but the corresponding tensor does not exist."};
164 if (gradient ==
nullptr)
165 throw std::runtime_error{
166 "Tried to get a gradient tensor but the corresponding tensor does not exist."};
173 assert(tensor !=
nullptr);
175 throw std::runtime_error{
176 "Tried to set a trainable tensor but another tensor already exists."};
178 _migrant[index] = tensor;
184 assert(tensor !=
nullptr);
186 throw std::runtime_error{
187 "Tried to set a trainable tensor but another tensor already exists."};
189 _non_const[index] = std::move(tensor);
194 assert(tensor !=
nullptr);
196 throw std::runtime_error{
197 "Tried to set a trainable tensor but another tensor already exists."};
199 _trainable[index] = std::move(tensor);
204 assert(tensor !=
nullptr);
205 auto itr = _back_prop.find(index);
206 if (itr != _back_prop.end())
207 throw std::runtime_error{
"Tried to set a back propagation tensor but another back "
208 "propagation tensor already exists."};
210 _back_prop[index] = std::move(tensor);
215 assert(tensor !=
nullptr);
216 auto itr = _gradient.find(index);
217 if (itr != _gradient.end())
218 throw std::runtime_error{
219 "Tried to set a gradient tensor but another gradient tensor already exists."};
221 _gradient[index] = std::move(tensor);
A tensor class that is portable for other backends.
virtual ITensor * getBackPropITensor(const ir::OperandIndex &)=0
Returns a pointer to the ITensor used for back propagation.
virtual ITensor * getGradientITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor for gradient.
virtual void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &) const =0
Iterate ITrainableTensors with fn.
A tensor class that can be trained.
ITensor * getGradientITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor for gradient.
IPortableTensor * getPortableTensor(const ir::OperandIndex &index)
Tensor * getNonConstTensor(const ir::OperandIndex &index)
void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
void setNonConstTensor(const ir::OperandIndex &index, std::unique_ptr< Tensor > tensor)
std::tuple< TrainableTensor *, GradientTensor * > TrainingTensors
ITensor * getBackPropITensor(const ir::OperandIndex &index) override
Returns a pointer to the ITensor used for back propagation.
void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
ITensor * getNativeITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native tensors.
ITensor * getITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native and migrant tensors.
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & nonconst_tensors()
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & trainable_tensors()
TrainableTensor * getTrainableTensor(const ir::OperandIndex &index)
GradientTensor * getGradientTensor(const ir::OperandIndex &index)
void setTrainableTensor(const ir::OperandIndex &index, std::unique_ptr< TrainableTensor > tensor)
BackPropTensor * getBackPropTensor(const ir::OperandIndex &index)
bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
Sets a migrant tensor, i.e. a tensor that comes from another backend.
TrainingTensors getTrainingTensors(const ir::OperandIndex &index)
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn) const override
Iterate ITrainableTensors with fn.
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & gradient_tensors()
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & back_prop_tensors()
basic::train::TrainableTensor TrainableTensor
std::unordered_map< OperandIndex, T > OperandIndexMap
virtual ITensor * getITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor among native and migrant tensors.
virtual ITensor * getNativeITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor among native tensors.