ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
ITensorRegistry.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
18#define __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
19
22
24{
25
27{
28public:
35
41
48
55
61 const std::function<void(const ir::OperandIndex &, const train::ITrainableTensor *)> &)
62 const = 0;
63};
64
65} // namespace onert::backend::train
66
68{
69
70template <typename Tensor, typename TrainableTensor, typename BackPropTensor,
71 typename GradientTensor>
73{
74public:
75 using TrainingTensors = std::tuple<TrainableTensor *, GradientTensor *>;
76
77public:
78 ITensor *getITensor(const ir::OperandIndex &index) override
79 {
80 auto _migrant_tensor = _migrant.find(index);
81 if (_migrant_tensor != _migrant.end())
82 return _migrant_tensor->second;
83 return getNativeITensor(index);
84 }
85
87 {
88 ITensor *tensor = getTrainableTensor(index);
89 if (tensor == nullptr)
90 tensor = getNonConstTensor(index);
91 return tensor;
92 }
93
95 {
96 return getBackPropTensor(index);
97 }
98
100 {
101 return getGradientTensor(index);
102 }
103
105 const std::function<void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn)
106 const override
107 {
108 for (const auto &[index, tensor] : _trainable)
109 fn(index, tensor.get());
110 }
111
113 {
114 auto tensor = _trainable.find(index);
115 if (tensor != _trainable.end())
116 {
117 if (tensor->second)
118 return tensor->second.get();
119 }
120 return getNonConstTensor(index);
121 }
122
124 {
125 auto tensor = _non_const.find(index);
126 if (tensor != _non_const.end())
127 return tensor->second.get();
128 return nullptr;
129 }
130
132 {
133 auto tensor = _trainable.find(index);
134 if (tensor != _trainable.end())
135 return tensor->second.get();
136
137 return nullptr;
138 }
139
141 {
142 auto tensor = _back_prop.find(index);
143 if (tensor != _back_prop.end())
144 return tensor->second.get();
145 return nullptr;
146 }
147
149 {
150 auto tensor = _gradient.find(index);
151 if (tensor != _gradient.end())
152 return tensor->second.get();
153 return nullptr;
154 }
155
157 {
158 auto trainable = getTrainableTensor(index);
159 if (trainable == nullptr)
160 throw std::runtime_error{
161 "Tried to get a trainable tensor but the corresponding tensor does not exist."};
162
163 auto gradient = getGradientTensor(index);
164 if (gradient == nullptr)
165 throw std::runtime_error{
166 "Tried to get a gradient tensor but the corresponding tensor does not exist."};
167
168 return TrainingTensors{std::make_pair(trainable, gradient)};
169 }
170
171 bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
172 {
173 assert(tensor != nullptr);
174 if (getITensor(index) != nullptr)
175 throw std::runtime_error{
176 "Tried to set a trainable tensor but another tensor already exists."};
177
178 _migrant[index] = tensor;
179 return true;
180 }
181
182 void setNonConstTensor(const ir::OperandIndex &index, std::unique_ptr<Tensor> tensor)
183 {
184 assert(tensor != nullptr);
185 if (getITensor(index) != nullptr)
186 throw std::runtime_error{
187 "Tried to set a trainable tensor but another tensor already exists."};
188
189 _non_const[index] = std::move(tensor);
190 }
191
192 void setTrainableTensor(const ir::OperandIndex &index, std::unique_ptr<TrainableTensor> tensor)
193 {
194 assert(tensor != nullptr);
195 if (getITensor(index) != nullptr)
196 throw std::runtime_error{
197 "Tried to set a trainable tensor but another tensor already exists."};
198
199 _trainable[index] = std::move(tensor);
200 }
201
202 void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr<BackPropTensor> tensor)
203 {
204 assert(tensor != nullptr);
205 auto itr = _back_prop.find(index);
206 if (itr != _back_prop.end())
207 throw std::runtime_error{"Tried to set a back propagation tensor but another back "
208 "propagation tensor already exists."};
209
210 _back_prop[index] = std::move(tensor);
211 }
212
213 void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor)
214 {
215 assert(tensor != nullptr);
216 auto itr = _gradient.find(index);
217 if (itr != _gradient.end())
218 throw std::runtime_error{
219 "Tried to set a gradient tensor but another gradient tensor already exists."};
220
221 _gradient[index] = std::move(tensor);
222 }
223
234
235private:
236 // Native tensors
239
240 // Migrant tensors
242
243 // Tensors for backpropagation
245
246 // Tensors for updating trainable tensors
248};
249
250} // namespace onert::backend::train
251
252#endif // __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
A tensor class that is portable for other backends.
virtual ITensor * getBackPropITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor for back propagation.
virtual ITensor * getGradientITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor for gradient.
virtual void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &) const =0
Iterate ITrainableTensors with fn.
A tensor class that can be trained.
ITensor * getGradientITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor for gradient.
IPortableTensor * getPortableTensor(const ir::OperandIndex &index)
Tensor * getNonConstTensor(const ir::OperandIndex &index)
void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
void setNonConstTensor(const ir::OperandIndex &index, std::unique_ptr< Tensor > tensor)
std::tuple< TrainableTensor *, GradientTensor * > TrainingTensors
ITensor * getBackPropITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor for back propagation.
void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
ITensor * getNativeITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native tensors.
ITensor * getITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native and migrant tensors.
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & nonconst_tensors()
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & trainable_tensors()
TrainableTensor * getTrainableTensor(const ir::OperandIndex &index)
GradientTensor * getGradientTensor(const ir::OperandIndex &index)
void setTrainableTensor(const ir::OperandIndex &index, std::unique_ptr< TrainableTensor > tensor)
BackPropTensor * getBackPropTensor(const ir::OperandIndex &index)
bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
Set the Migrant Tensor which are from other backends.
TrainingTensors getTrainingTensors(const ir::OperandIndex &index)
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn) const override
Iterate ITrainableTensors with fn.
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & gradient_tensors()
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & back_prop_tensors()
Tensor BackPropTensor
Definition Tensor.h:43
Tensor GradientTensor
Definition Tensor.h:44
basic::train::TrainableTensor TrainableTensor
Definition Tensor.h:42
std::unordered_map< OperandIndex, T > OperandIndexMap
virtual ITensor * getITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor among native and migrant tensors.
virtual ITensor * getNativeITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor among native tensors.