ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TensorRegistries.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
18#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
19
20#include "../../backend/builtin/Config.h"
21#include "../../backend/builtin/train/TensorRegistry.h"
22
25
26#include <memory>
27#include <unordered_set>
28
29namespace onert
30{
31namespace compiler
32{
33namespace train
34{
35
37{
38public:
39 TensorRegistries() = default;
40
42 bool include_builtin)
43 {
44 for (const auto &e : backend_contexts)
45 {
46 auto tensor_reg = e.second->tensor_registry();
47 if (e.first->config()->id() == backend::builtin::Config::ID)
48 {
49 _builtin_tensor_reg =
50 std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
51 if (include_builtin)
52 _tensor_regs.insert(tensor_reg);
53 }
54 else
55 {
56 _tensor_regs.insert(tensor_reg);
57 }
58 }
59 }
60
61 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator begin() const
62 {
63 return _tensor_regs.cbegin();
64 }
65 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator end() const
66 {
67 return _tensor_regs.cend();
68 }
69
70 std::shared_ptr<backend::builtin::train::TensorRegistry> getBuiltinTensorRegistry() const
71 {
72 return _builtin_tensor_reg;
73 }
74
76 {
77 for (const auto &tensor_reg : _tensor_regs)
78 {
79 auto tensor = tensor_reg->getITensor(index);
80 if (tensor)
81 return tensor;
82 }
83 return nullptr;
84 }
85
87 {
88 for (const auto &tensor_reg : _tensor_regs)
89 {
90 auto tensor = tensor_reg->getBackPropITensor(index);
91 if (tensor)
92 return tensor;
93 }
94 return nullptr;
95 }
96
98 const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)>
99 &fn) const
100 {
101 for (const auto &tensor_reg : _tensor_regs)
102 tensor_reg->iterateTrainableTensors(fn);
103 }
104
105private:
106 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs;
107 std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg;
108};
109
110} // namespace train
111} // namespace compiler
112} // namespace onert
113
114#endif // __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
static std::string ID
Definition Config.h:34
A tensor class that can be trained.
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator end() const
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator begin() const
TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts, bool include_builtin)
backend::ITensor * getITensor(ir::OperandIndex index) const
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const
backend::ITensor * getBackPropITensor(ir::OperandIndex index) const
std::shared_ptr< backend::builtin::train::TensorRegistry > getBuiltinTensorRegistry() const
std::unordered_map< const Backend *, std::unique_ptr< TrainableBackendContext > > TrainableBackendContexts