ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::builtin::train::BackendContext Class Reference

#include <BackendContext.h>

Collaboration diagram for onert::backend::builtin::train::BackendContext:

Public Member Functions

 BackendContext (const backend::train::ITrainableBackend *backend, std::unique_ptr< backend::train::TrainableContextData > &&data, std::shared_ptr< backend::train::ITensorRegistry > tensor_registry=nullptr, std::shared_ptr< TensorBuilder > tensor_builder=nullptr, std::shared_ptr< KernelGenerator > kernel_gen=nullptr)
 
backend::train::FunctionMap gen () override
 
std::shared_ptr< ExternalContext > external_context ()
 
- Public Member Functions inherited from onert::backend::train::TrainableBackendContext
 TrainableBackendContext (const ITrainableBackend *backend, std::unique_ptr< TrainableContextData > &&tdata, std::shared_ptr< ITensorRegistry > tensor_registry=nullptr)
 
virtual ~TrainableBackendContext ()=default
 
const ir::train::TrainableGraph * trainable_graph () const
 
const TrainableContextData * data () const
 
const ITrainableBackend * backend () const
 
const util::Set< ir::OperandIndex > & external_operands () const
 
const ir::OperandIndexMap< ir::Layout > & operand_layouts () const
 
std::shared_ptr< ITensorRegistry > tensor_registry ()
 

Data Fields

std::shared_ptr< KernelGenerator > kernel_gen
 

Additional Inherited Members

- Protected Attributes inherited from onert::backend::train::TrainableBackendContext
std::unique_ptr< TrainableContextData > _tdata
 
std::shared_ptr< ITensorRegistry > _tensor_registry
 

Detailed Description

Definition at line 35 of file BackendContext.h.

Constructor & Destructor Documentation

◆ BackendContext()

onert::backend::builtin::train::BackendContext::BackendContext ( const backend::train::ITrainableBackend * backend,
std::unique_ptr< backend::train::TrainableContextData > &&  data,
std::shared_ptr< backend::train::ITensorRegistry > tensor_registry = nullptr,
std::shared_ptr< TensorBuilder > tensor_builder = nullptr,
std::shared_ptr< KernelGenerator > kernel_gen = nullptr 
)
inline

Definition at line 38 of file BackendContext.h.

43 : backend::train::TrainableBackendContext(backend, std::move(data), tensor_registry),
44 kernel_gen{kernel_gen}, _external_context(new ExternalContext),
45 _tensor_builder{tensor_builder}
46 {
47 }
std::shared_ptr< KernelGenerator > kernel_gen
std::shared_ptr< ITensorRegistry > tensor_registry()

Member Function Documentation

◆ external_context()

std::shared_ptr< ExternalContext > onert::backend::builtin::train::BackendContext::external_context ( )
inline

Definition at line 52 of file BackendContext.h.

52{ return _external_context; }

◆ gen()

backend::train::FunctionMap onert::backend::builtin::train::BackendContext::gen ( )
override virtual

Implements onert::backend::train::TrainableBackendContext.

Definition at line 31 of file BackendContext.cc.

32{
33 // For now, there is no need to generate tensors for forwarding and backwarding.
34 // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
35 // `Permute`: Tensor generation is not required.
36 // `IF`, `WHILE`: Not supported yet
37
38 backend::train::FunctionMap fn_map;
39
40 for (auto &&op_ind : _tdata->op_order)
41 {
42 auto tn_seq = kernel_gen->generate(op_ind);
43 fn_map.emplace(op_ind, std::move(tn_seq));
44 }
45
46 trainable_graph()->operands().iterate(
47 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
48 if (!external_operands().contains(ind) && operand.isConstant())
49 {
50 throw std::runtime_error(
51 "BackendContext: builtin backend does not support updatable weights yet");
52 }
53 });
54
55 // TODO Enable prepare()
56 // for (auto &&it : fn_map)
57 // {
58 // auto &fn_seq = it.second;
59 // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
60 // }
61
62 return fn_map;
63}
const ir::train::TrainableGraph * trainable_graph() const
const util::Set< ir::OperandIndex > & external_operands() const
std::unique_ptr< TrainableContextData > _tdata
const Operands & operands() const override
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
std::unordered_map< ir::OperationIndex, std::unique_ptr< exec::train::TrainableFnSequence > > FunctionMap
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:35

References onert::backend::train::TrainableBackendContext::_tdata, onert::backend::train::TrainableBackendContext::external_operands(), onert::ir::Operand::isConstant(), onert::util::ObjectManager< Index, Object >::iterate(), kernel_gen, onert::ir::train::TrainableGraph::operands(), and onert::backend::train::TrainableBackendContext::trainable_graph().

Field Documentation

◆ kernel_gen

std::shared_ptr<KernelGenerator> onert::backend::builtin::train::BackendContext::kernel_gen

Definition at line 56 of file BackendContext.h.

Referenced by gen().


The documentation for this class was generated from the following files:
- BackendContext.h
- BackendContext.cc