17#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
18#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
20#include "../../backend/builtin/Config.h"
21#include "../../backend/builtin/train/TensorRegistry.h"
27#include <unordered_set>
40 for (
const auto &e : backend_contexts)
42 auto tensor_reg = e.second->tensor_registry();
46 std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
48 _tensor_regs.insert(tensor_reg);
52 _tensor_regs.insert(tensor_reg);
57 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator
begin()
const
59 return _tensor_regs.cbegin();
57 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator
begin()
const {
…}
61 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator
end()
const
63 return _tensor_regs.cend();
61 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator
end()
const {
…}
68 return _builtin_tensor_reg;
73 for (
const auto &tensor_reg : _tensor_regs)
75 auto tensor = tensor_reg->getITensor(index);
84 for (
const auto &tensor_reg : _tensor_regs)
86 auto tensor = tensor_reg->getBackPropITensor(index);
97 for (
const auto &tensor_reg : _tensor_regs)
98 tensor_reg->iterateTrainableTensors(fn);
102 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs;
103 std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg;
A tensor class that can be trained.
TensorRegistries()=default
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator end() const
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator begin() const
TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts, bool include_builtin)
backend::ITensor * getITensor(ir::OperandIndex index) const
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const
backend::ITensor * getBackPropITensor(ir::OperandIndex index) const
std::shared_ptr< backend::builtin::train::TensorRegistry > getBuiltinTensorRegistry() const
std::unordered_map< const Backend *, std::unique_ptr< TrainableBackendContext > > TrainableBackendContexts