ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::compiler::train::TensorRegistries Class Reference

#include <TensorRegistries.h>

Public Member Functions

 TensorRegistries ()=default
 
 TensorRegistries (const backend::train::TrainableBackendContexts &backend_contexts, bool include_builtin)
 
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator begin () const
 
std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator end () const
 
std::shared_ptr< backend::builtin::train::TensorRegistry > getBuiltinTensorRegistry () const
 
backend::ITensor * getITensor (ir::OperandIndex index) const
 
backend::ITensor * getBackPropITensor (ir::OperandIndex index) const
 
void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const
 

Detailed Description

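Collects the tensor registries of a set of trainable backend contexts so that tensors can be looked up across all of them. When a builtin backend is present, its registry is always kept separately (retrievable via getBuiltinTensorRegistry()). However, it takes part in iteration and in getITensor()/getBackPropITensor() lookups only when include_builtin is true. Each lookup returns the first non-null tensor found among the stored registries, or nullptr if none has one.

Definition at line 32 of file TensorRegistries.h.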

Constructor & Destructor Documentation

◆ TensorRegistries() [1/2]

onert::compiler::train::TensorRegistries::TensorRegistries ( )
default

◆ TensorRegistries() [2/2]

onert::compiler::train::TensorRegistries::TensorRegistries ( const backend::train::TrainableBackendContexts & backend_contexts,
bool  include_builtin 
)
inline

Definition at line 37 of file TensorRegistries.h.

39 {
40 for (const auto &e : backend_contexts)
41 {
42 auto tensor_reg = e.second->tensor_registry();
43 if (e.first->config()->id() == backend::builtin::Config::ID)
44 {
45 _builtin_tensor_reg =
46 std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
47 if (include_builtin)
48 _tensor_regs.insert(tensor_reg);
49 }
50 else
51 {
52 _tensor_regs.insert(tensor_reg);
53 }
54 }
55 }
static std::string ID
Definition Config.h:30

References onert::backend::builtin::Config::ID.

Member Function Documentation

◆ begin()

std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator onert::compiler::train::TensorRegistries::begin ( ) const
inline

Definition at line 57 of file TensorRegistries.h.

58 {
59 return _tensor_regs.cbegin();
60 }

◆ end()

std::unordered_set< std::shared_ptr< backend::train::ITensorRegistry > >::const_iterator onert::compiler::train::TensorRegistries::end ( ) const
inline

Definition at line 61 of file TensorRegistries.h.

62 {
63 return _tensor_regs.cend();
64 }

◆ getBackPropITensor()

backend::ITensor * onert::compiler::train::TensorRegistries::getBackPropITensor ( ir::OperandIndex  index) const
inline

Definition at line 82 of file TensorRegistries.h.

83 {
84 for (const auto &tensor_reg : _tensor_regs)
85 {
86 auto tensor = tensor_reg->getBackPropITensor(index);
87 if (tensor)
88 return tensor;
89 }
90 return nullptr;
91 }

◆ getBuiltinTensorRegistry()

std::shared_ptr< backend::builtin::train::TensorRegistry > onert::compiler::train::TensorRegistries::getBuiltinTensorRegistry ( ) const
inline

Definition at line 66 of file TensorRegistries.h.

67 {
68 return _builtin_tensor_reg;
69 }

◆ getITensor()

backend::ITensor * onert::compiler::train::TensorRegistries::getITensor ( ir::OperandIndex  index) const
inline

Definition at line 71 of file TensorRegistries.h.

72 {
73 for (const auto &tensor_reg : _tensor_regs)
74 {
75 auto tensor = tensor_reg->getITensor(index);
76 if (tensor)
77 return tensor;
78 }
79 return nullptr;
80 }

Referenced by onert::exec::train::TrainableExecutor::getLoss(), and onert::exec::train::TrainableExecutor::TrainableExecutor().

◆ iterateTrainableTensors()

void onert::compiler::train::TensorRegistries::iterateTrainableTensors ( const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &  fn) const
inline

Definition at line 93 of file TensorRegistries.h.

96 {
97 for (const auto &tensor_reg : _tensor_regs)
98 tensor_reg->iterateTrainableTensors(fn);
99 }
void iterateTrainableTensors(const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn) const

Referenced by onert::exec::train::TrainableExecutor::iterateTrainableTensors().


The documentation for this class was generated from the following file: