ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ITensorRegistry.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
18#define __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
19
22
23namespace onert
24{
25namespace backend
26{
27namespace train
28{
29
31{
32public:
39
45
52
59
65 const std::function<void(const ir::OperandIndex &, const train::ITrainableTensor *)> &)
66 const = 0;
67};
68
69} // namespace train
70} // namespace backend
71} // namespace onert
72
73namespace onert
74{
75namespace backend
76{
77namespace train
78{
79
80template <typename Tensor, typename TrainableTensor, typename BackPropTensor,
81 typename GradientTensor>
83{
84public:
85 using TrainingTensors = std::tuple<TrainableTensor *, GradientTensor *>;
86
87public:
88 ITensor *getITensor(const ir::OperandIndex &index) override
89 {
90 auto _migrant_tensor = _migrant.find(index);
91 if (_migrant_tensor != _migrant.end())
92 return _migrant_tensor->second;
93 return getNativeITensor(index);
94 }
95
97 {
98 ITensor *tensor = getTrainableTensor(index);
99 if (tensor == nullptr)
100 tensor = getNonConstTensor(index);
101 return tensor;
102 }
103
105 {
106 return getBackPropTensor(index);
107 }
108
110 {
111 return getGradientTensor(index);
112 }
113
115 const std::function<void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn)
116 const override
117 {
118 for (const auto &[index, tensor] : _trainable)
119 fn(index, tensor.get());
120 }
121
123 {
124 auto tensor = _trainable.find(index);
125 if (tensor != _trainable.end())
126 {
127 if (tensor->second)
128 return tensor->second.get();
129 }
130 return getNonConstTensor(index);
131 }
132
134 {
135 auto tensor = _non_const.find(index);
136 if (tensor != _non_const.end())
137 return tensor->second.get();
138 return nullptr;
139 }
140
142 {
143 auto tensor = _trainable.find(index);
144 if (tensor != _trainable.end())
145 return tensor->second.get();
146
147 return nullptr;
148 }
149
151 {
152 auto tensor = _back_prop.find(index);
153 if (tensor != _back_prop.end())
154 return tensor->second.get();
155 return nullptr;
156 }
157
159 {
160 auto tensor = _gradient.find(index);
161 if (tensor != _gradient.end())
162 return tensor->second.get();
163 return nullptr;
164 }
165
167 {
168 auto trainable = getTrainableTensor(index);
169 if (trainable == nullptr)
170 throw std::runtime_error{
171 "Tried to get a trainable tensor but the corresponding tensor does not exist."};
172
173 auto gradient = getGradientTensor(index);
174 if (gradient == nullptr)
175 throw std::runtime_error{
176 "Tried to get a gradient tensor but the corresponding tensor does not exist."};
177
178 return TrainingTensors{std::make_pair(trainable, gradient)};
179 }
180
181 bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
182 {
183 assert(tensor != nullptr);
184 if (getITensor(index) != nullptr)
185 throw std::runtime_error{
186 "Tried to set a trainable tensor but another tensor already exists."};
187
188 _migrant[index] = tensor;
189 return true;
190 }
191
192 void setNonConstTensor(const ir::OperandIndex &index, std::unique_ptr<Tensor> tensor)
193 {
194 assert(tensor != nullptr);
195 if (getITensor(index) != nullptr)
196 throw std::runtime_error{
197 "Tried to set a trainable tensor but another tensor already exists."};
198
199 _non_const[index] = std::move(tensor);
200 }
201
202 void setTrainableTensor(const ir::OperandIndex &index, std::unique_ptr<TrainableTensor> tensor)
203 {
204 assert(tensor != nullptr);
205 if (getITensor(index) != nullptr)
206 throw std::runtime_error{
207 "Tried to set a trainable tensor but another tensor already exists."};
208
209 _trainable[index] = std::move(tensor);
210 }
211
212 void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr<BackPropTensor> tensor)
213 {
214 assert(tensor != nullptr);
215 auto itr = _back_prop.find(index);
216 if (itr != _back_prop.end())
217 throw std::runtime_error{"Tried to set a back propagation tensor but another back "
218 "propagation tensor already exists."};
219
220 _back_prop[index] = std::move(tensor);
221 }
222
223 void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor)
224 {
225 assert(tensor != nullptr);
226 auto itr = _gradient.find(index);
227 if (itr != _gradient.end())
228 throw std::runtime_error{
229 "Tried to set a gradient tensor but another gradient tensor already exists."};
230
231 _gradient[index] = std::move(tensor);
232 }
233
244
245private:
246 // Native tensors
249
250 // Migrant tensors
252
253 // Tensors for backpropagation
255
256 // Tensors for updating trainable tensors
258};
259
260} // namespace train
261} // namespace backend
262} // namespace onert
263
264#endif // __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
A tensor class that is portable for other backends.
virtual ITensor * getBackPropITensor(const ir::OperandIndex &)=0
Returns a pointer to the ITensor used for back propagation.
virtual ITensor * getGradientITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor for gradient.
virtual void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &) const =0
Iterate ITrainableTensors with fn.
A tensor class that can be trained.
ITensor * getGradientITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor for gradient.
IPortableTensor * getPortableTensor(const ir::OperandIndex &index)
Tensor * getNonConstTensor(const ir::OperandIndex &index)
void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
void setNonConstTensor(const ir::OperandIndex &index, std::unique_ptr< Tensor > tensor)
std::tuple< TrainableTensor *, GradientTensor * > TrainingTensors
ITensor * getBackPropITensor(const ir::OperandIndex &index) override
Returns a pointer to the ITensor used for back propagation.
void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
ITensor * getNativeITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native tensors.
ITensor * getITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native and migrant tensors.
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & nonconst_tensors()
const ir::OperandIndexMap< std::unique_ptr< TrainableTensor > > & trainable_tensors()
TrainableTensor * getTrainableTensor(const ir::OperandIndex &index)
GradientTensor * getGradientTensor(const ir::OperandIndex &index)
void setTrainableTensor(const ir::OperandIndex &index, std::unique_ptr< TrainableTensor > tensor)
BackPropTensor * getBackPropTensor(const ir::OperandIndex &index)
bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
Sets a migrant tensor, i.e. a tensor that comes from another backend.
TrainingTensors getTrainingTensors(const ir::OperandIndex &index)
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &fn) const override
Iterate ITrainableTensors with fn.
const ir::OperandIndexMap< std::unique_ptr< GradientTensor > > & gradient_tensors()
const ir::OperandIndexMap< std::unique_ptr< Tensor > > & back_prop_tensors()
Tensor BackPropTensor
Definition Tensor.h:47
Tensor GradientTensor
Definition Tensor.h:48
basic::train::TrainableTensor TrainableTensor
Definition Tensor.h:46
std::unordered_map< OperandIndex, T > OperandIndexMap
virtual ITensor * getITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor among native and migrant tensors.
virtual ITensor * getNativeITensor(const ir::OperandIndex &)=0
Returns pointer of ITensor among native tensors.