ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::basic::train::TrainableTensor Class Reference

#include <TrainableTensor.h>

Collaboration diagram for onert::backend::basic::train::TrainableTensor:

Public Member Functions

 TrainableTensor ()=delete
 
virtual ~TrainableTensor ()=default
 
 TrainableTensor (const ir::OperandInfo &info)
 
void setBuffer (uint8_t *buffer)
 Set the buffer object. This method is called for static, non-const tensors.
 
uint8_t * buffer () const override
 
std::vector< ITensor * > optVars () override
 Get optimizer variables of this trainable tensor.
 
void appendOptVar (std::unique_ptr< Tensor > opt_var)
 
void setOptVarBuffer (uint8_t *buffer, size_t pos)
 
void fillBuffer (const std::shared_ptr< ir::Data > &data)
 
- Public Member Functions inherited from onert::backend::train::ITrainableTensor
virtual ~ITrainableTensor ()=default
 
 IPortableTensor (const ir::OperandInfo &info)
 
- Public Member Functions inherited from onert::backend::IPortableTensor
 IPortableTensor (const ir::OperandInfo &info)
 
virtual ~IPortableTensor ()
 
const ir::OperandInfo & get_info () const
 
const ir::Sparsity * sparsity () const
 
size_t total_size () const override final
 
size_t calcOffset (const ir::Coordinates &coords) const override final
 
ir::DataType data_type () const override final
 
float data_scale () const override final
 
int32_t data_zero_point () const override final
 
const std::vector< float > & data_scales () const override final
 
const std::vector< int32_t > & data_zero_points () const override
 
bool is_constant () const override final
 Return true if the tensor is constant.
 
bool is_dynamic () const override final
 Return true if the tensor needs dynamic allocation, i.e. its output shape cannot be known at compile time and is instead calculated at kernel execution time.
 
ir::Shape getShape () const override final
 Get ir::Shape of tensor.
 
bool has_padding () const final
 
void access (const std::function< void(ITensor &tensor)> &fn) final
 
- Public Member Functions inherited from onert::backend::ITensor
virtual ~ITensor ()
 
virtual void deallocBuffer ()
 Deallocate the buffer (only for dynamic tensors).
 
virtual bool is_subtensor () const
 
virtual bool needMemoryMap () const
 
virtual void enqueueWriteBuffer (const void *, bool)
 
virtual void enqueueReadBuffer (void *, bool)
 

Protected Attributes

Tensor _tensor
 
std::vector< std::unique_ptr< Tensor > > _opt_vars
 
- Protected Attributes inherited from onert::backend::IPortableTensor
ir::OperandInfo _info
 

Detailed Description

Definition at line 27 of file TrainableTensor.h.

Constructor & Destructor Documentation

◆ TrainableTensor() [1/2]

onert::backend::basic::train::TrainableTensor::TrainableTensor ( )
delete

◆ ~TrainableTensor()

virtual onert::backend::basic::train::TrainableTensor::~TrainableTensor ( )
virtual, default

◆ TrainableTensor() [2/2]

onert::backend::basic::train::TrainableTensor::TrainableTensor ( const ir::OperandInfo & info)
inline

Definition at line 34 of file TrainableTensor.h.

35 : ITrainableTensor{info}, _tensor{info, nullptr}, _opt_vars{}
36 {
37 // DO NOTHING
38 }
std::vector< std::unique_ptr< Tensor > > _opt_vars
volatile const char info[]

Member Function Documentation

◆ appendOptVar()

void onert::backend::basic::train::TrainableTensor::appendOptVar ( std::unique_ptr< Tensor > opt_var)
inline

Definition at line 51 of file TrainableTensor.h.

51{ _opt_vars.emplace_back(std::move(opt_var)); }

References _opt_vars.

◆ buffer()

uint8_t * onert::backend::basic::train::TrainableTensor::buffer ( ) const
inline, override, virtual

Implements onert::backend::ITensor.

Definition at line 47 of file TrainableTensor.h.

47{ return _tensor.buffer(); }
uint8_t * buffer() const override
Definition Tensor.h:69

References _tensor, and onert::backend::basic::Tensor::buffer().

Referenced by fillBuffer(), setBuffer(), and setOptVarBuffer().

◆ fillBuffer()

void onert::backend::basic::train::TrainableTensor::fillBuffer ( const std::shared_ptr< ir::Data > &  data)

Definition at line 32 of file TrainableTensor.cc.

33{
34 auto *buffer = _tensor.buffer();
35 assert(buffer);
36 assert(total_size() == data->size());
37 std::memcpy(buffer, data->base(), data->size());
38}
size_t total_size() const override final

References _tensor, onert::backend::basic::Tensor::buffer(), buffer(), and onert::backend::IPortableTensor::total_size().

Referenced by onert::backend::train::BackendContext::gen().

◆ optVars()

std::vector< ITensor * > onert::backend::basic::train::TrainableTensor::optVars ( )
override, virtual

Get optimizer variables of this trainable tensor.

Returns
Optimizer variables

Implements onert::backend::train::ITrainableTensor.

Definition at line 22 of file TrainableTensor.cc.

23{
24 std::vector<ITensor *> ret;
25 for (auto &&e : _opt_vars)
26 {
27 ret.emplace_back(e.get());
28 }
29 return ret;
30}

References _opt_vars.

◆ setBuffer()

void onert::backend::basic::train::TrainableTensor::setBuffer ( uint8_t *  buffer)
inline

Set the buffer object. This method is called for static, non-const tensors.

Definition at line 44 of file TrainableTensor.h.

void setBuffer(uint8_t *buffer)
Set the buffer object. This method is called for static, non-const tensors.
Definition Tensor.h:52

References _tensor, buffer(), and onert::backend::basic::Tensor::setBuffer().

◆ setOptVarBuffer()

void onert::backend::basic::train::TrainableTensor::setOptVarBuffer ( uint8_t *  buffer,
size_t  pos 
)
inline

Definition at line 52 of file TrainableTensor.h.

52{ _opt_vars.at(pos)->setBuffer(buffer); }

References _opt_vars, and buffer().

Field Documentation

◆ _opt_vars

std::vector<std::unique_ptr<Tensor> > onert::backend::basic::train::TrainableTensor::_opt_vars
protected

Definition at line 64 of file TrainableTensor.h.

Referenced by appendOptVar(), optVars(), and setOptVarBuffer().

◆ _tensor

Tensor onert::backend::basic::train::TrainableTensor::_tensor
protected

Definition at line 63 of file TrainableTensor.h.

Referenced by buffer(), fillBuffer(), and setBuffer().


The documentation for this class was generated from the following files: