ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::train::BackendContext Class Reference

#include <BackendContext.h>

Collaboration diagram for onert::backend::train::BackendContext:

Public Member Functions

 BackendContext (const ITrainableBackend *backend, std::unique_ptr< TrainableContextData > &&tdata, std::shared_ptr< backend::train::ITensorRegistry > tensor_registry=nullptr, std::shared_ptr< TensorBuilder > tensor_builder=nullptr, std::unique_ptr< exec::train::optimizer::Optimizer > optimizer=nullptr, std::shared_ptr< KernelGenerator > kernel_gen=nullptr)
 
 BackendContext (const BackendContext &)=delete
 
 ~BackendContext ()=default
 
BackendContext & operator= (const BackendContext &)=delete
 
FunctionMap gen () override
 
std::shared_ptr< ExternalContext > external_context ()
 
const exec::train::optimizer::Optimizer * optimizer () const
 
- Public Member Functions inherited from onert::backend::train::TrainableBackendContext
 TrainableBackendContext (const ITrainableBackend *backend, std::unique_ptr< TrainableContextData > &&tdata, std::shared_ptr< ITensorRegistry > tensor_registry=nullptr)
 
virtual ~TrainableBackendContext ()=default
 
const ir::train::TrainableGraph * trainable_graph () const
 
const TrainableContextData * data () const
 
const ITrainableBackend * backend () const
 
const util::Set< ir::OperandIndex > & external_operands () const
 
const ir::OperandIndexMap< ir::Layout > & operand_layouts () const
 
std::shared_ptr< ITensorRegistry > tensor_registry ()
 

Data Fields

std::shared_ptr< KernelGenerator > kernel_gen
 

Additional Inherited Members

- Protected Attributes inherited from onert::backend::train::TrainableBackendContext
std::unique_ptr< TrainableContextData > _tdata
 
std::shared_ptr< ITensorRegistry > _tensor_registry
 

Detailed Description

Definition at line 47 of file BackendContext.h.

Constructor & Destructor Documentation

◆ BackendContext() [1/2]

onert::backend::train::BackendContext::BackendContext ( const ITrainableBackend * backend,
std::unique_ptr< TrainableContextData > && tdata,
std::shared_ptr< backend::train::ITensorRegistry > tensor_registry = nullptr,
std::shared_ptr< TensorBuilder > tensor_builder = nullptr,
std::unique_ptr< exec::train::optimizer::Optimizer > optimizer = nullptr,
std::shared_ptr< KernelGenerator > kernel_gen = nullptr
)
inline

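Definition at line 50 of file BackendContext.h. (The excerpt below starts at line 56, partway through the member-initializer list. Lines 50–55 are not shown.)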

56 kernel_gen{kernel_gen}, _external_context(new ExternalContext),
57 _tensor_builder{tensor_builder}, _optimizer{std::move(optimizer)}
58 {
59 }
std::shared_ptr< KernelGenerator > kernel_gen
std::shared_ptr< ITensorRegistry > tensor_registry()
cpu::ExternalContext ExternalContext

◆ BackendContext() [2/2]

onert::backend::train::BackendContext::BackendContext ( const BackendContext & )
delete

◆ ~BackendContext()

onert::backend::train::BackendContext::~BackendContext ( )
default

Member Function Documentation

◆ external_context()

std::shared_ptr< ExternalContext > onert::backend::train::BackendContext::external_context ( )
inline

Definition at line 75 of file BackendContext.h.

75{ return _external_context; }

◆ gen()

FunctionMap onert::backend::train::BackendContext::gen ( )
override virtual

Implements onert::backend::train::TrainableBackendContext.

Definition at line 135 of file BackendContext.cc.

136{
137 planForwardTensors();
138 planBackwardTensors();
139
140 _tensor_builder->allocate();
141 _tensor_builder->allocateBackward();
142
143 auto fn_map = generateFunctionMap();
144
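145 // Initialize TrainableTensors
146 trainable_graph()->operands().iterate(   [line 146 was missing from the extracted page; reconstructed from the References list below — confirm against BackendContext.cc]
147 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {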
148 if (external_operands().contains(ind) || !operand.isConstant())
149 return;
150
151 auto tensor = tensor_registry()->getNativeITensor(ind);
152 assert(tensor != nullptr);
153
154 VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;
155
156 auto data = operand.shareData();
157 assert(data && data->base());
158 auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);
159
160 if (trainable_tensor == nullptr)
161 throw std::runtime_error{"This tensor is not trainable tensor"};
162
163 trainable_tensor->fillBuffer(data);
164 });
165
166 // NOTE For memory optimization, we want to free some operand data
167 const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
168 .operands()
169 .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
170
171 // TODO Enable
172 // for (auto &&it : ret)
173 // {
174 // auto &fn_seq = it.second;
175 // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
176 // }
177
178 // NOTE: Since LayerScopeTensors is defined in each kernel(layer),
179 // It should be planned and allocated after the kernels generated.
180 planLayerScopeTensors(fn_map);
181 _tensor_builder->allocateLayerScope();
182
183 return fn_map;
184}
void fillBuffer(const std::shared_ptr< ir::Data > &data)
const ir::train::TrainableGraph * trainable_graph() const
const util::Set< ir::OperandIndex > & external_operands() const
std::unique_ptr< TrainableContextData > _tdata
const Operands & operands() const override
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
#define VERBOSE(name, lv)
Definition Log.h:71
basic::train::TrainableTensor TrainableTensor
Definition Tensor.h:42
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:33

References onert::backend::train::TrainableBackendContext::_tdata, onert::backend::train::TrainableBackendContext::data(), onert::backend::train::TrainableBackendContext::external_operands(), onert::backend::basic::train::TrainableTensor::fillBuffer(), onert::ir::Operand::isConstant(), onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::train::TrainableGraph::operands(), onert::ir::Operand::shareData(), onert::backend::train::TrainableBackendContext::tensor_registry(), onert::backend::train::TrainableBackendContext::trainable_graph(), and VERBOSE.

◆ operator=()

BackendContext & onert::backend::train::BackendContext::operator= ( const BackendContext & )
delete

◆ optimizer()

const exec::train::optimizer::Optimizer * onert::backend::train::BackendContext::optimizer ( ) const
inline

Field Documentation

◆ kernel_gen

std::shared_ptr<KernelGenerator> onert::backend::train::BackendContext::kernel_gen

Definition at line 84 of file BackendContext.h.


The documentation for this class was generated from the following files: