ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::backend::basic::train Namespace Reference

Data Structures

class  TrainableTensor
 

Functions

template<typename TensorBuilder >
ITensorRegistry * genTensors (backend::train::TrainableBackendContext &ctx, const std::shared_ptr< TensorBuilder > &tensor_builder)
 

Function Documentation

◆ genTensors()

template<typename TensorBuilder >
ITensorRegistry * onert::backend::basic::train::genTensors ( backend::train::TrainableBackendContext & ctx,
const std::shared_ptr< TensorBuilder > &  tensor_builder 
)

Definition at line 28 of file TrainableBackendContextHelpers.h.

30{
31 const auto &tgraph = *ctx.trainable_graph();
32
33 tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
34 if (ctx.external_operands().contains(ind))
35 return;
36 tensor_builder->registerTensorInfo(ind, obj.info());
37 });
38
39 // For the executors that does not have fixed linear execution order:
40 // To make tensors never be deallocated, this is a workaround to use static memory planner
41 tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
42 if (tensor_builder->isRegistered(ind))
43 tensor_builder->notifyFirstUse(ind);
44 });
45
46 tensor_builder->allocate();
47
48 return ctx.tensor_registry().get();
49}
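Declarations of the members used above:
std::shared_ptr< ITensorRegistry > backend::train::TrainableBackendContext::tensor_registry()
const ir::train::TrainableGraph * backend::train::TrainableBackendContext::trainable_graph() const
const util::Set< ir::OperandIndex > & backend::train::TrainableBackendContext::external_operands() const
const Operands & ir::train::TrainableGraph::operands() const override
void util::ObjectManager< Index, Object >::iterate(const std::function< void(const Index &, const Object &)> &fn) const
    Iterate over the container with the given function.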

References onert::backend::train::TrainableBackendContext::external_operands(), onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::train::TrainableGraph::operands(), onert::backend::train::TrainableBackendContext::tensor_registry(), and onert::backend::train::TrainableBackendContext::trainable_graph().