ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
TensorRegistries.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
18#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
19
20#include "../../backend/builtin/Config.h"
21#include "../../backend/builtin/train/TensorRegistry.h"
22
25
26#include <memory>
27#include <unordered_set>
28
30{
31
33{
34public:
35 TensorRegistries() = default;
36
38 bool include_builtin)
39 {
40 for (const auto &e : backend_contexts)
41 {
42 auto tensor_reg = e.second->tensor_registry();
43 if (e.first->config()->id() == backend::builtin::Config::ID)
44 {
45 _builtin_tensor_reg =
46 std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
47 if (include_builtin)
48 _tensor_regs.insert(tensor_reg);
49 }
50 else
51 {
52 _tensor_regs.insert(tensor_reg);
53 }
54 }
55 }
56
57 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator begin() const
58 {
59 return _tensor_regs.cbegin();
60 }
61 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator end() const
62 {
63 return _tensor_regs.cend();
64 }
65
66 std::shared_ptr<backend::builtin::train::TensorRegistry> getBuiltinTensorRegistry() const
67 {
68 return _builtin_tensor_reg;
69 }
70
72 {
73 for (const auto &tensor_reg : _tensor_regs)
74 {
75 auto tensor = tensor_reg->getITensor(index);
76 if (tensor)
77 return tensor;
78 }
79 return nullptr;
80 }
81
83 {
84 for (const auto &tensor_reg : _tensor_regs)
85 {
86 auto tensor = tensor_reg->getBackPropITensor(index);
87 if (tensor)
88 return tensor;
89 }
90 return nullptr;
91 }
92
94 const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)>
95 &fn) const
96 {
97 for (const auto &tensor_reg : _tensor_regs)
98 tensor_reg->iterateTrainableTensors(fn);
99 }
100
101private:
102 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs;
103 std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg;
104};
105
106} // namespace onert::compiler::train
107
108#endif // __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
static std::string ID
Definition Config.h:30
A tensor class that can be trained.
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator end() const
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator begin() const
TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts, bool include_builtin)
backend::ITensor * getITensor(ir::OperandIndex index) const
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const
backend::ITensor * getBackPropITensor(ir::OperandIndex index) const
std::shared_ptr< backend::builtin::train::TensorRegistry > getBuiltinTensorRegistry() const
std::unordered_map< const Backend *, std::unique_ptr< TrainableBackendContext > > TrainableBackendContexts