ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::builtin::train::BackendContext Class Reference

#include <BackendContext.h>

Collaboration diagram for onert::backend::builtin::train::BackendContext:

Public Member Functions

 BackendContext (const backend::train::ITrainableBackend *backend, std::unique_ptr< backend::train::TrainableContextData > &&data, std::shared_ptr< backend::train::ITensorRegistry > tensor_registry=nullptr, std::shared_ptr< TensorBuilder > tensor_builder=nullptr, std::shared_ptr< KernelGenerator > kernel_gen=nullptr)
 
backend::train::FunctionMap gen () override
 
std::shared_ptr< ExternalContext > external_context ()
 
- Public Member Functions inherited from onert::backend::train::TrainableBackendContext
 TrainableBackendContext (const ITrainableBackend *backend, std::unique_ptr< TrainableContextData > &&tdata, std::shared_ptr< ITensorRegistry > tensor_registry=nullptr)
 
virtual ~TrainableBackendContext ()=default
 
const ir::train::TrainableGraph * trainable_graph () const
 
const TrainableContextData * data () const
 
const ITrainableBackend * backend () const
 
const util::Set< ir::OperandIndex > & external_operands () const
 
const ir::OperandIndexMap< ir::Layout > & operand_layouts () const
 
std::shared_ptr< ITensorRegistry > tensor_registry ()
 

Data Fields

std::shared_ptr< KernelGenerator > kernel_gen
 

Additional Inherited Members

- Protected Attributes inherited from onert::backend::train::TrainableBackendContext
std::unique_ptr< TrainableContextData > _tdata
 
std::shared_ptr< ITensorRegistry > _tensor_registry
 

Detailed Description

Definition at line 29 of file BackendContext.h.

Constructor & Destructor Documentation

◆ BackendContext()

onert::backend::builtin::train::BackendContext::BackendContext ( const backend::train::ITrainableBackend * backend,
std::unique_ptr< backend::train::TrainableContextData > && data,
std::shared_ptr< backend::train::ITensorRegistry > tensor_registry = nullptr,
std::shared_ptr< TensorBuilder > tensor_builder = nullptr,
std::shared_ptr< KernelGenerator > kernel_gen = nullptr 
)
inline

Definition at line 32 of file BackendContext.h.

37 : backend::train::TrainableBackendContext(backend, std::move(data), tensor_registry),
38 kernel_gen{kernel_gen}, _external_context(new ExternalContext),
39 _tensor_builder{tensor_builder}
40 {
41 }
std::shared_ptr< KernelGenerator > kernel_gen
std::shared_ptr< ITensorRegistry > tensor_registry()

Member Function Documentation

◆ external_context()

std::shared_ptr< ExternalContext > onert::backend::builtin::train::BackendContext::external_context ( )
inline

Definition at line 46 of file BackendContext.h.

46{ return _external_context; }

◆ gen()

backend::train::FunctionMap onert::backend::builtin::train::BackendContext::gen ( )
override virtual

Implements onert::backend::train::TrainableBackendContext.

Definition at line 25 of file BackendContext.cc.

26{
27 // For now, there is no need to generate tensors for forwarding and backwarding.
28 // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
29 // `Permute`: Tensor generation is not required.
30 // `IF`, `WHILE`: Not supported yet
31
32 backend::train::FunctionMap fn_map;
33
34 for (auto &&op_ind : _tdata->op_order)
35 {
36 auto tn_seq = kernel_gen->generate(op_ind);
37 fn_map.emplace(op_ind, std::move(tn_seq));
38 }
39
40 trainable_graph()->operands().iterate(
41 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
42 if (!external_operands().contains(ind) && operand.isConstant())
43 {
44 throw std::runtime_error(
45 "BackendContext: builtin backend does not support updatable weights yet");
46 }
47 });
48
49 // TODO Enable prepare()
50 // for (auto &&it : fn_map)
51 // {
52 // auto &fn_seq = it.second;
53 // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
54 // }
55
56 return fn_map;
57}
const ir::train::TrainableGraph * trainable_graph() const
const util::Set< ir::OperandIndex > & external_operands() const
std::unique_ptr< TrainableContextData > _tdata
const Operands & operands() const override
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
std::unordered_map< ir::OperationIndex, std::unique_ptr< exec::train::TrainableFnSequence > > FunctionMap
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:33

References onert::backend::train::TrainableBackendContext::_tdata, onert::backend::train::TrainableBackendContext::external_operands(), onert::ir::Operand::isConstant(), onert::util::ObjectManager< Index, Object >::iterate(), kernel_gen, onert::ir::train::TrainableGraph::operands(), and onert::backend::train::TrainableBackendContext::trainable_graph().

Field Documentation

◆ kernel_gen

std::shared_ptr<KernelGenerator> onert::backend::builtin::train::BackendContext::kernel_gen

Definition at line 50 of file BackendContext.h.

Referenced by gen().


The documentation for this class was generated from the following files: