17#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
18#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
20#include "../../backend/builtin/Config.h"
21#include "../../backend/builtin/train/TensorRegistry.h"
27#include <unordered_set>
44 for (
const auto &e : backend_contexts)
46 auto tensor_reg = e.second->tensor_registry();
50 std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
52 _tensor_regs.insert(tensor_reg);
56 _tensor_regs.insert(tensor_reg);
61 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator
begin()
const
63 return _tensor_regs.cbegin();
65 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator
end()
const
67 return _tensor_regs.cend();
72 return _builtin_tensor_reg;
77 for (
const auto &tensor_reg : _tensor_regs)
79 auto tensor = tensor_reg->getITensor(index);
88 for (
const auto &tensor_reg : _tensor_regs)
90 auto tensor = tensor_reg->getBackPropITensor(index);
101 for (
const auto &tensor_reg : _tensor_regs)
102 tensor_reg->iterateTrainableTensors(fn);
106 std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs;
107 std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg;
A tensor class that can be trained.
TensorRegistries()=default
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator end() const
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator begin() const
TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts, bool include_builtin)
backend::ITensor * getITensor(ir::OperandIndex index) const
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const
backend::ITensor * getBackPropITensor(ir::OperandIndex index) const
std::shared_ptr< backend::builtin::train::TensorRegistry > getBuiltinTensorRegistry() const
std::unordered_map< const Backend *, std::unique_ptr< TrainableBackendContext > > TrainableBackendContexts