ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::builtin::train::TensorRegistry Class Reference

#include <TensorRegistry.h>

Collaboration diagram for onert::backend::builtin::train::TensorRegistry:

Public Member Functions

 TensorRegistry ()
 
ITensor * getITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor found among native and migrant tensors.
 
ITensor * getNativeITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor found among native tensors only.
 
IPortableTensor * getPortableTensor (const ir::OperandIndex &index)
 
IOTensor * getNativeIOTensor (const ir::OperandIndex &index)
 
ITensor * getBackPropITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor used for back propagation.
 
ITensor * getGradientITensor (const ir::OperandIndex &index) override
 Returns a pointer to the ITensor used for the gradient.
 
BackPropTensor * getBackPropTensor (const ir::OperandIndex &index)
 
bool setMigrantTensor (const ir::OperandIndex &index, IPortableTensor *tensor) override
 Sets a migrant tensor, i.e. a tensor that comes from another backend.
 
void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &) const override
 
void setBackPropTensor (const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
 
void setGradientTensor (const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
 
void setNativeIOTensor (ir::OperandIndex index, std::unique_ptr< IOTensor > &&tensor)
 
const ir::OperandIndexMap< std::unique_ptr< IOTensor > > & native_io_tensors ()
 
std::shared_ptr< BaseTensorRegistry > base_reg ()
 
- Public Member Functions inherited from onert::backend::train::ITensorRegistry
virtual void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &) const =0
 Iterates over ITrainableTensors, calling fn for each.
 
- Public Member Functions inherited from onert::backend::ITensorRegistry
virtual ~ITensorRegistry ()=default
 Destructs itself.
 

Detailed Description

Definition at line 33 of file TensorRegistry.h.

Constructor & Destructor Documentation

◆ TensorRegistry()

onert::backend::builtin::train::TensorRegistry::TensorRegistry ( )
inline

Definition at line 36 of file TensorRegistry.h.

36: _base_reg{new BaseTensorRegistry} {}
backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor > BaseTensorRegistry

Member Function Documentation

◆ base_reg()

std::shared_ptr< BaseTensorRegistry > onert::backend::builtin::train::TensorRegistry::base_reg ( )
inline

Definition at line 122 of file TensorRegistry.h.

122{ return _base_reg; }

◆ getBackPropITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getBackPropITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns a pointer to the ITensor used for back propagation.

Note
Returned tensor cannot be used longer than dynamic tensor manager

Implements onert::backend::train::ITensorRegistry.

Definition at line 70 of file TensorRegistry.h.

71 {
72 return _base_reg->getBackPropTensor(index);
73 }

◆ getBackPropTensor()

BackPropTensor * onert::backend::builtin::train::TensorRegistry::getBackPropTensor ( const ir::OperandIndex & index )
inline

Definition at line 80 of file TensorRegistry.h.

81 {
82 return _base_reg->getBackPropTensor(index);
83 }

◆ getGradientITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getGradientITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns pointer of ITensor for gradient.

Note
Returned tensor cannot be used longer than dynamic tensor manager

Implements onert::backend::train::ITensorRegistry.

Definition at line 75 of file TensorRegistry.h.

76 {
77 return _base_reg->getGradientTensor(index);
78 }

◆ getITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns a pointer to the ITensor found among native and migrant tensors.

A native tensor is a tensor that is managed by this backend. A migrant tensor is a tensor that is imported from another backend.

Note
Returned tensor cannot be used longer than dynamic tensor manager

Implements onert::backend::ITensorRegistry.

Definition at line 38 of file TensorRegistry.h.

39 {
40 auto base_tensor = _base_reg->getITensor(index);
41 if (base_tensor)
42 return base_tensor;
43 return getNativeIOTensor(index);
44 }
IOTensor * getNativeIOTensor(const ir::OperandIndex &index)

References getNativeIOTensor().

Referenced by setMigrantTensor(), and setNativeIOTensor().

◆ getNativeIOTensor()

IOTensor * onert::backend::builtin::train::TensorRegistry::getNativeIOTensor ( const ir::OperandIndex & index )
inline

Definition at line 62 of file TensorRegistry.h.

63 {
64 auto tensor = _native_io_tensors.find(index);
65 if (tensor != _native_io_tensors.end())
66 return tensor->second.get();
67 return nullptr;
68 }

Referenced by getITensor(), getNativeITensor(), and getPortableTensor().

◆ getNativeITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getNativeITensor ( const ir::OperandIndex & index )
inlineoverridevirtual

Returns a pointer to the ITensor found among native tensors only.

Unlike getITensor, this function searches only native tensors.

Note
Returned tensor cannot be used longer than dynamic tensor manager

Implements onert::backend::ITensorRegistry.

Definition at line 46 of file TensorRegistry.h.

47 {
48 auto base_tensor = _base_reg->getNativeITensor(index);
49 if (base_tensor)
50 return base_tensor;
51 return getNativeIOTensor(index);
52 }

References getNativeIOTensor().

◆ getPortableTensor()

IPortableTensor * onert::backend::builtin::train::TensorRegistry::getPortableTensor ( const ir::OperandIndex & index )
inline

Definition at line 54 of file TensorRegistry.h.

55 {
56 auto base_tensor = _base_reg->getPortableTensor(index);
57 if (base_tensor)
58 return base_tensor;
59 return getNativeIOTensor(index);
60 }

References getNativeIOTensor().

◆ iterateTrainableTensors()

void onert::backend::builtin::train::TensorRegistry::iterateTrainableTensors ( const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &  ) const
inlineoverride

Definition at line 93 of file TensorRegistry.h.

96 {
97 // DO NOTHING
98 // Builtin tensor registry does not have trainable tensor.
99 }

◆ native_io_tensors()

const ir::OperandIndexMap< std::unique_ptr< IOTensor > > & onert::backend::builtin::train::TensorRegistry::native_io_tensors ( )
inline

Definition at line 118 of file TensorRegistry.h.

119 {
120 return _native_io_tensors;
121 }

◆ setBackPropTensor()

void onert::backend::builtin::train::TensorRegistry::setBackPropTensor ( const ir::OperandIndex index,
std::unique_ptr< BackPropTensor > tensor 
)
inline

Definition at line 101 of file TensorRegistry.h.

102 {
103 _base_reg->setBackPropTensor(index, std::move(tensor));
104 }

◆ setGradientTensor()

void onert::backend::builtin::train::TensorRegistry::setGradientTensor ( const ir::OperandIndex index,
std::unique_ptr< GradientTensor > tensor 
)
inline

Definition at line 106 of file TensorRegistry.h.

107 {
108 _base_reg->setGradientTensor(index, std::move(tensor));
109 }

◆ setMigrantTensor()

bool onert::backend::builtin::train::TensorRegistry::setMigrantTensor ( const ir::OperandIndex & index,
IPortableTensor * tensor 
)
inlineoverridevirtual

Sets a migrant tensor, i.e. a tensor that comes from another backend.

Returns
true if supported
false if not supported

Reimplemented from onert::backend::ITensorRegistry.

Definition at line 85 of file TensorRegistry.h.

86 {
87 assert(tensor);
88 assert(!getITensor(index)); // For the index, tensor is not registered yet
89 _base_reg->setMigrantTensor(index, tensor);
90 return true;
91 }
ITensor * getITensor(const ir::OperandIndex &index) override
Returns a pointer to the ITensor found among native and migrant tensors.

References getITensor().

◆ setNativeIOTensor()

void onert::backend::builtin::train::TensorRegistry::setNativeIOTensor ( ir::OperandIndex  index,
std::unique_ptr< IOTensor > &&  tensor 
)
inline

Definition at line 111 of file TensorRegistry.h.

112 {
113 assert(tensor);
114 assert(!getITensor(index)); // For the index, tensor is not registered yet
115 _native_io_tensors[index] = std::move(tensor);
116 }
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

References getITensor().


The documentation for this class was generated from the following file: