ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::compiler::train::TensorRegistries Class Reference

#include <TensorRegistries.h>

Public Member Functions

 TensorRegistries ()=default
 
 TensorRegistries (const backend::train::TrainableBackendContexts &backend_contexts, bool include_builtin)
 
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator begin () const
 
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator end () const
 
std::shared_ptr< backend::builtin::train::TensorRegistry > getBuiltinTensorRegistry () const
 
backend::ITensor * getITensor (ir::OperandIndex index) const
 
backend::ITensor * getBackPropITensor (ir::OperandIndex index) const
 
void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const
 

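Detailed Description

Collects the tensor registries of the given trainable backend contexts and looks up tensors across all of them. getITensor() and getBackPropITensor() return the first non-null match, or nullptr if no registry has one. The builtin backend's registry is always stored separately and returned by getBuiltinTensorRegistry(). It is also added to the iterated set (begin()/end(), iterateTrainableTensors(), and the lookups) only when include_builtin is true.

Definition at line 36 of file TensorRegistries.h.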

Constructor & Destructor Documentation

◆ TensorRegistries() [1/2]

onert::compiler::train::TensorRegistries::TensorRegistries ( )
default

◆ TensorRegistries() [2/2]

onert::compiler::train::TensorRegistries::TensorRegistries ( const backend::train::TrainableBackendContexts & backend_contexts,
bool  include_builtin 
)
inline

Definition at line 41 of file TensorRegistries.h.

43 {
44 for (const auto &e : backend_contexts)
45 {
46 auto tensor_reg = e.second->tensor_registry();
47 if (e.first->config()->id() == backend::builtin::Config::ID)
48 {
49 _builtin_tensor_reg =
50 std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
51 if (include_builtin)
52 _tensor_regs.insert(tensor_reg);
53 }
54 else
55 {
56 _tensor_regs.insert(tensor_reg);
57 }
58 }
59 }
Referenced symbol: static std::string onert::backend::builtin::Config::ID
(defined at Config.h:34)

References onert::backend::builtin::Config::ID.

Member Function Documentation

◆ begin()

std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator onert::compiler::train::TensorRegistries::begin ( ) const
inline

Definition at line 61 of file TensorRegistries.h.

62 {
63 return _tensor_regs.cbegin();
64 }

◆ end()

std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator onert::compiler::train::TensorRegistries::end ( ) const
inline

Definition at line 65 of file TensorRegistries.h.

66 {
67 return _tensor_regs.cend();
68 }

◆ getBackPropITensor()

backend::ITensor * onert::compiler::train::TensorRegistries::getBackPropITensor ( ir::OperandIndex  index) const
inline

Definition at line 86 of file TensorRegistries.h.

87 {
88 for (const auto &tensor_reg : _tensor_regs)
89 {
90 auto tensor = tensor_reg->getBackPropITensor(index);
91 if (tensor)
92 return tensor;
93 }
94 return nullptr;
95 }

◆ getBuiltinTensorRegistry()

std::shared_ptr< backend::builtin::train::TensorRegistry > onert::compiler::train::TensorRegistries::getBuiltinTensorRegistry ( ) const
inline

Definition at line 70 of file TensorRegistries.h.

71 {
72 return _builtin_tensor_reg;
73 }

◆ getITensor()

backend::ITensor * onert::compiler::train::TensorRegistries::getITensor ( ir::OperandIndex  index) const
inline

Definition at line 75 of file TensorRegistries.h.

76 {
77 for (const auto &tensor_reg : _tensor_regs)
78 {
79 auto tensor = tensor_reg->getITensor(index);
80 if (tensor)
81 return tensor;
82 }
83 return nullptr;
84 }

Referenced by onert::exec::train::TrainableExecutor::getLoss(), and onert::exec::train::TrainableExecutor::TrainableExecutor().

◆ iterateTrainableTensors()

void onert::compiler::train::TensorRegistries::iterateTrainableTensors ( const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &  fn) const
inline

Definition at line 97 of file TensorRegistries.h.

100 {
101 for (const auto &tensor_reg : _tensor_regs)
102 tensor_reg->iterateTrainableTensors(fn);
103 }
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const

Referenced by onert::exec::train::TrainableExecutor::iterateTrainableTensors().


The documentation for this class was generated from the following file: