ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::builtin::train::TensorRegistry Class Reference

#include <TensorRegistry.h>

Collaboration diagram for onert::backend::builtin::train::TensorRegistry:

Public Member Functions

 TensorRegistry ()
 
ITensor * getITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor among native and migrant tensors.
 
ITensor * getNativeITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor among native tensors.
 
IPortableTensor * getPortableTensor (const ir::OperandIndex &index)
 
IOTensor * getNativeIOTensor (const ir::OperandIndex &index)
 
ITensor * getBackPropITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor for back propagation.
 
ITensor * getGradientITensor (const ir::OperandIndex &index) override
 Returns pointer of ITensor for gradient.
 
BackPropTensor * getBackPropTensor (const ir::OperandIndex &index)
 
bool setMigrantTensor (const ir::OperandIndex &index, IPortableTensor *tensor) override
 Set a migrant tensor, i.e. a tensor that comes from another backend.
 
void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &) const override
 
void setBackPropTensor (const ir::OperandIndex &index, std::unique_ptr< BackPropTensor > tensor)
 
void setGradientTensor (const ir::OperandIndex &index, std::unique_ptr< GradientTensor > tensor)
 
void setNativeIOTensor (ir::OperandIndex index, std::unique_ptr< IOTensor > &&tensor)
 
const ir::OperandIndexMap< std::unique_ptr< IOTensor > > & native_io_tensors ()
 
std::shared_ptr< BaseTensorRegistry > base_reg ()
 
- Public Member Functions inherited from onert::backend::train::ITensorRegistry
virtual void iterateTrainableTensors (const std::function< void(const ir::OperandIndex &, const train::ITrainableTensor *)> &) const =0
 Iterate ITrainableTensors with fn.
 
- Public Member Functions inherited from onert::backend::ITensorRegistry
virtual ~ITensorRegistry ()=default
 Destruct itself.
 

Detailed Description

Definition at line 39 of file TensorRegistry.h.

Constructor & Destructor Documentation

◆ TensorRegistry()

onert::backend::builtin::train::TensorRegistry::TensorRegistry ( )
inline

Definition at line 42 of file TensorRegistry.h.

42: _base_reg{new BaseTensorRegistry} {}
backend::train::PortableTensorRegistryTemplate< Tensor, TrainableTensor, BackPropTensor, GradientTensor > BaseTensorRegistry

Member Function Documentation

◆ base_reg()

std::shared_ptr< BaseTensorRegistry > onert::backend::builtin::train::TensorRegistry::base_reg ( )
inline

Definition at line 128 of file TensorRegistry.h.

128{ return _base_reg; }

◆ getBackPropITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getBackPropITensor ( const ir::OperandIndex & index )
inline override virtual

Returns pointer of ITensor for back propagation.

Note
Returned tensor cannot be used longer than the dynamic tensor manager.

Implements onert::backend::train::ITensorRegistry.

Definition at line 76 of file TensorRegistry.h.

77 {
78 return _base_reg->getBackPropTensor(index);
79 }

◆ getBackPropTensor()

BackPropTensor * onert::backend::builtin::train::TensorRegistry::getBackPropTensor ( const ir::OperandIndex & index )
inline

Definition at line 86 of file TensorRegistry.h.

87 {
88 return _base_reg->getBackPropTensor(index);
89 }

◆ getGradientITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getGradientITensor ( const ir::OperandIndex & index )
inline override virtual

Returns pointer of ITensor for gradient.

Note
Returned tensor cannot be used longer than the dynamic tensor manager.

Implements onert::backend::train::ITensorRegistry.

Definition at line 81 of file TensorRegistry.h.

82 {
83 return _base_reg->getGradientTensor(index);
84 }

◆ getITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getITensor ( const ir::OperandIndex & index )
inline override virtual

Returns pointer of ITensor among native and migrant tensors.

A native tensor is a tensor that is managed by this backend. A migrant tensor is a tensor that is imported from another backend.

Note
Returned tensor cannot be used longer than the dynamic tensor manager.

Implements onert::backend::ITensorRegistry.

Definition at line 44 of file TensorRegistry.h.

45 {
46 auto base_tensor = _base_reg->getITensor(index);
47 if (base_tensor)
48 return base_tensor;
49 return getNativeIOTensor(index);
50 }
IOTensor * getNativeIOTensor(const ir::OperandIndex &index)

References getNativeIOTensor().

Referenced by setMigrantTensor(), and setNativeIOTensor().

◆ getNativeIOTensor()

IOTensor * onert::backend::builtin::train::TensorRegistry::getNativeIOTensor ( const ir::OperandIndex & index )
inline

Definition at line 68 of file TensorRegistry.h.

69 {
70 auto tensor = _native_io_tensors.find(index);
71 if (tensor != _native_io_tensors.end())
72 return tensor->second.get();
73 return nullptr;
74 }

Referenced by getITensor(), getNativeITensor(), and getPortableTensor().

◆ getNativeITensor()

ITensor * onert::backend::builtin::train::TensorRegistry::getNativeITensor ( const ir::OperandIndex & index )
inline override virtual

Returns pointer of ITensor among native tensors.

Unlike getITensor, this function searches only native tensors.

Note
Returned tensor cannot be used longer than the dynamic tensor manager.

Implements onert::backend::ITensorRegistry.

Definition at line 52 of file TensorRegistry.h.

53 {
54 auto base_tensor = _base_reg->getNativeITensor(index);
55 if (base_tensor)
56 return base_tensor;
57 return getNativeIOTensor(index);
58 }

References getNativeIOTensor().

◆ getPortableTensor()

IPortableTensor * onert::backend::builtin::train::TensorRegistry::getPortableTensor ( const ir::OperandIndex & index )
inline

Definition at line 60 of file TensorRegistry.h.

61 {
62 auto base_tensor = _base_reg->getPortableTensor(index);
63 if (base_tensor)
64 return base_tensor;
65 return getNativeIOTensor(index);
66 }

References getNativeIOTensor().

◆ iterateTrainableTensors()

void onert::backend::builtin::train::TensorRegistry::iterateTrainableTensors ( const std::function< void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &  ) const
inlineoverride

Definition at line 99 of file TensorRegistry.h.

102 {
103 // DO NOTHING
104 // Builtin tensor registry does not have trainable tensor.
105 }

◆ native_io_tensors()

const ir::OperandIndexMap< std::unique_ptr< IOTensor > > & onert::backend::builtin::train::TensorRegistry::native_io_tensors ( )
inline

Definition at line 124 of file TensorRegistry.h.

125 {
126 return _native_io_tensors;
127 }

◆ setBackPropTensor()

void onert::backend::builtin::train::TensorRegistry::setBackPropTensor ( const ir::OperandIndex index,
std::unique_ptr< BackPropTensor > tensor
)
inline

Definition at line 107 of file TensorRegistry.h.

108 {
109 _base_reg->setBackPropTensor(index, std::move(tensor));
110 }

◆ setGradientTensor()

void onert::backend::builtin::train::TensorRegistry::setGradientTensor ( const ir::OperandIndex index,
std::unique_ptr< GradientTensor > tensor
)
inline

Definition at line 112 of file TensorRegistry.h.

113 {
114 _base_reg->setGradientTensor(index, std::move(tensor));
115 }

◆ setMigrantTensor()

bool onert::backend::builtin::train::TensorRegistry::setMigrantTensor ( const ir::OperandIndex & index,
IPortableTensor * tensor
)
inline override virtual

Set a migrant tensor, i.e. a tensor that comes from another backend.

Returns
true if supported
false if not supported

Reimplemented from onert::backend::ITensorRegistry.

Definition at line 91 of file TensorRegistry.h.

92 {
93 assert(tensor);
94 assert(!getITensor(index)); // For the index, tensor is not registered yet
95 _base_reg->setMigrantTensor(index, tensor);
96 return true;
97 }
ITensor * getITensor(const ir::OperandIndex &index) override
Returns pointer of ITensor among native and migrant tensors.

References getITensor().

◆ setNativeIOTensor()

void onert::backend::builtin::train::TensorRegistry::setNativeIOTensor ( ir::OperandIndex  index,
std::unique_ptr< IOTensor > &&  tensor 
)
inline

Definition at line 117 of file TensorRegistry.h.

118 {
119 assert(tensor);
120 assert(!getITensor(index)); // For the index, tensor is not registered yet
121 _native_io_tensors[index] = std::move(tensor);
122 }
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

References getITensor().


The documentation for this class was generated from the following file: