ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::train::BackendContext Class Reference

#include <BackendContext.h>

Collaboration diagram for onert::backend::train::BackendContext:

Public Member Functions

 BackendContext (const ITrainableBackend *backend, std::unique_ptr< TrainableContextData > &&tdata, std::shared_ptr< backend::train::ITensorRegistry > tensor_registry=nullptr, std::shared_ptr< TensorBuilder > tensor_builder=nullptr, std::unique_ptr< exec::train::optimizer::Optimizer > optimizer=nullptr, std::shared_ptr< KernelGenerator > kernel_gen=nullptr)
 
 BackendContext (const BackendContext &)=delete
 
 ~BackendContext ()=default
 
BackendContext & operator= (const BackendContext &)=delete
 
FunctionMap gen () override
 
std::shared_ptr< ExternalContext > external_context ()
 
const exec::train::optimizer::Optimizer * optimizer () const
 
- Public Member Functions inherited from onert::backend::train::TrainableBackendContext
 TrainableBackendContext (const ITrainableBackend *backend, std::unique_ptr< TrainableContextData > &&tdata, std::shared_ptr< ITensorRegistry > tensor_registry=nullptr)
 
virtual ~TrainableBackendContext ()=default
 
const ir::train::TrainableGraph * trainable_graph () const
 
const TrainableContextData * data () const
 
const ITrainableBackend * backend () const
 
const util::Set< ir::OperandIndex > & external_operands () const
 
const ir::OperandIndexMap< ir::Layout > & operand_layouts () const
 
std::shared_ptr< ITensorRegistry > tensor_registry ()
 

Data Fields

std::shared_ptr< KernelGenerator > kernel_gen
 

Additional Inherited Members

- Protected Attributes inherited from onert::backend::train::TrainableBackendContext
std::unique_ptr< TrainableContextData > _tdata
 
std::shared_ptr< ITensorRegistry > _tensor_registry
 

Detailed Description

Definition at line 51 of file BackendContext.h.

Constructor & Destructor Documentation

◆ BackendContext() [1/2]

onert::backend::train::BackendContext::BackendContext ( const ITrainableBackend * backend,
std::unique_ptr< TrainableContextData > &&  tdata,
std::shared_ptr< backend::train::ITensorRegistry > tensor_registry = nullptr,
std::shared_ptr< TensorBuilder > tensor_builder = nullptr,
std::unique_ptr< exec::train::optimizer::Optimizer > optimizer = nullptr,
std::shared_ptr< KernelGenerator > kernel_gen = nullptr 
)
inline

Definition at line 54 of file BackendContext.h.

60 kernel_gen{kernel_gen}, _external_context(new ExternalContext),
61 _tensor_builder{tensor_builder}, _optimizer{std::move(optimizer)}
62 {
63 }
std::shared_ptr< KernelGenerator > kernel_gen
const exec::train::optimizer::Optimizer * optimizer() const
std::shared_ptr< ITensorRegistry > tensor_registry()
cpu::ExternalContext ExternalContext

◆ BackendContext() [2/2]

onert::backend::train::BackendContext::BackendContext ( const BackendContext & )
delete

◆ ~BackendContext()

onert::backend::train::BackendContext::~BackendContext ( )
default

Member Function Documentation

◆ external_context()

std::shared_ptr< ExternalContext > onert::backend::train::BackendContext::external_context ( )
inline

Definition at line 79 of file BackendContext.h.

79{ return _external_context; }

◆ gen()

FunctionMap onert::backend::train::BackendContext::gen ( )
override virtual

Implements onert::backend::train::TrainableBackendContext.

Definition at line 139 of file BackendContext.cc.

140{
141 planForwardTensors();
142 planBackwardTensors();
143
144 _tensor_builder->allocate();
145 _tensor_builder->allocateBackward();
146
147 auto fn_map = generateFunctionMap();
148
149 // Initialize TrainableTensors
150 trainable_graph()->operands().iterate(
151 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
152 if (external_operands().contains(ind) || !operand.isConstant())
153 return;
154
155 auto tensor = tensor_registry()->getNativeITensor(ind);
156 assert(tensor != nullptr);
157
158 VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;
159
160 auto data = operand.shareData();
161 assert(data && data->base());
162 auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);
163
164 if (trainable_tensor == nullptr)
165 throw std::runtime_error{"This tensor is not trainable tensor"};
166
167 trainable_tensor->fillBuffer(data);
168 });
169
170 // NOTE For memory optimization, we want to free some operand data
171 const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
172 .operands()
173 .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
174
175 // TODO Enable
176 // for (auto &&it : ret)
177 // {
178 // auto &fn_seq = it.second;
179 // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
180 // }
181
182 // NOTE: Since LayerScopeTensors is defined in each kernel(layer),
183 // It should be planned and allocated after the kernels generated.
184 planLayerScopeTensors(fn_map);
185 _tensor_builder->allocateLayerScope();
186
187 return fn_map;
188}
void fillBuffer(const std::shared_ptr< ir::Data > &data)
const ir::train::TrainableGraph * trainable_graph() const
const util::Set< ir::OperandIndex > & external_operands() const
std::unique_ptr< TrainableContextData > _tdata
const Operands & operands() const override
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
#define VERBOSE(name, lv)
Definition Log.h:71
basic::train::TrainableTensor TrainableTensor
Definition Tensor.h:46
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:35

References onert::backend::train::TrainableBackendContext::_tdata, onert::backend::train::TrainableBackendContext::data(), onert::backend::train::TrainableBackendContext::external_operands(), onert::backend::basic::train::TrainableTensor::fillBuffer(), onert::ir::Operand::isConstant(), onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::train::TrainableGraph::operands(), onert::ir::Operand::shareData(), onert::backend::train::TrainableBackendContext::tensor_registry(), onert::backend::train::TrainableBackendContext::trainable_graph(), and VERBOSE.

◆ operator=()

BackendContext & onert::backend::train::BackendContext::operator= ( const BackendContext & )
delete

◆ optimizer()

const exec::train::optimizer::Optimizer * onert::backend::train::BackendContext::optimizer ( ) const
inline

Definition at line 81 of file BackendContext.h.

81{ return _optimizer.get(); }

Field Documentation

◆ kernel_gen

std::shared_ptr<KernelGenerator> onert::backend::train::BackendContext::kernel_gen

Definition at line 88 of file BackendContext.h.


The documentation for this class was generated from the following files:
BackendContext.h
BackendContext.cc