ONE - On-device Neural Engine
#include <TrainableGraph.h>
Public Member Functions
TrainableGraph ()
  Construct a new Trainable Graph object.
TrainableGraph (const TrainableGraph &tgraph)
TrainableGraph (const Graph &graph)
~TrainableGraph ()=default
OperandIndex addOperand (const Shape &shape, const TypeInfo &type)
OperandIndex addOperand (OperandIndex index, std::unique_ptr< Operand > &&operand)
  Add an operand to the graph with the given index and object.
OperationIndex addOperation (std::unique_ptr< ITrainableOperation > &&operation)
  Add a new trainable operation to the graph.
OperationIndex replaceOperation (OperationIndex index, std::unique_ptr< ITrainableOperation > &&operation)
  Replace a trainable operation that the graph already has.
OperandIndex addBackwardOperand (OperandIndex index, std::unique_ptr< Operand > &&bwd_operand)
  Add an operand for the backward pass to the graph with the given index and object.
void changeShape (const OperandIndex &ind, const ir::Shape &new_shape) override
void changeBackwardShape (const OperandIndex &ind, const ir::Shape &new_shape)
void addInput (const OperandIndex &ind, const std::string &name="")
void addOutput (const OperandIndex &ind, const std::string &name="")
void addLoss (const OperandIndex &loss_ind, const IOIndex &pred_io_ind)
void verify () const
void removeOperand (const OperandIndex &ind)
void setInputs (OperandIndexSequence inputs, std::unordered_map< std::string, IOIndex > name_to_input)
void setOutputs (OperandIndexSequence outputs, std::unordered_map< std::string, IOIndex > name_to_output)
void enableBackward (const OperationIndex &index)
void disableBackward (const OperationIndex &index)
void setTrainingUseDefs (const UseDefChains &training_defuses)
const OperandIndexSequence & getInputs () const override
const OperandIndexSequence & getOutputs () const override
IOIndex getInputIndex (const std::string &name) const override
IOIndex getOutputIndex (const std::string &name) const override
const Operands & operands () const override
Operands & operands ()
const Operations & operations () const override
const Operands & backward_operands () const
OperandIndex getLossIndex (const IOIndex &pred_io_ind) const
const Graph & graph () const
const ITrainableOperation & operation (OperationIndex index) const
const UseDefChains & trainingUseDefs () const
std::vector< ir::OperationIndex > topolSortOperations () const
std::vector< ir::OperationIndex > btopolSortOperations () const
std::vector< ir::OperationIndex > essentialBackwardOrder () const
std::vector< ir::OperationIndex > truncateBackwardOrder (const std::vector< ir::OperationIndex > &backward_order) const
  Truncate the backward order of operations, using as the alive condition whether each operation has trainable parameters.
std::vector< ir::OperationIndex > truncateBackwardOrder (std::vector< ir::OperationIndex > backward_order, std::function< bool(const ir::OperationIndex &)> truncating_cond) const
  Truncate the backward order of operations in accordance with the given alive condition.
void updateGraphDependency ()

Public Member Functions inherited from onert::ir::IGraph
virtual ~IGraph ()=default
Definition at line 35 of file TrainableGraph.h.

Constructor & Destructor Documentation

onert::ir::train::TrainableGraph::TrainableGraph ( )  [explicit]
Construct a new Trainable Graph object.
Parameters
  graph
Definition at line 82 of file TrainableGraph.cc.
onert::ir::train::TrainableGraph::TrainableGraph (const TrainableGraph &tgraph)  [explicit]
Definition at line 84 of file TrainableGraph.cc.
References onert::ir::clone(), onert::util::ObjectManager< Index, Object >::iterate(), operations(), and replaceOperation().
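The copy constructor references onert::ir::clone(), iterate(), and replaceOperation(), which suggests that copying a TrainableGraph gives the copy its own clone of every operation rather than sharing them. The sketch below illustrates that deep-copy pattern under that assumption. Op, ReluOp, and MiniGraph are hypothetical stand-ins, not onert types, and the real constructor may be organized differently (for example, by cloning each operation and installing it through replaceOperation()).

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical polymorphic operation standing in for ITrainableOperation.
struct Op
{
  virtual ~Op() = default;
  virtual std::unique_ptr<Op> clone() const = 0;
  virtual std::string name() const = 0;
};

// Hypothetical concrete operation.
struct ReluOp : Op
{
  std::unique_ptr<Op> clone() const override { return std::make_unique<ReluOp>(*this); }
  std::string name() const override { return "ReLU"; }
};

// Hypothetical graph that owns its operations through unique_ptr.
class MiniGraph
{
public:
  MiniGraph() = default;

  // Deep copy: each operation is cloned, so the copy never shares an
  // operation object with the original.
  MiniGraph(const MiniGraph &other)
  {
    for (const auto &[index, op] : other._ops)
      _ops.emplace(index, op->clone());
  }

  void add(int index, std::unique_ptr<Op> op) { _ops.emplace(index, std::move(op)); }
  const Op &at(int index) const { return *_ops.at(index); }

private:
  std::map<int, std::unique_ptr<Op>> _ops;
};
```

Because the operations are held by std::unique_ptr, a member-wise copy would not even compile; an explicit clone() per operation is what makes the copy possible while keeping ownership unique.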
onert::ir::train::TrainableGraph::TrainableGraph (const Graph &graph)  [explicit]
Definition at line 94 of file TrainableGraph.cc.
onert::ir::train::TrainableGraph::~TrainableGraph ( )  [default]

Member Function Documentation

OperandIndex onert::ir::train::TrainableGraph::addBackwardOperand (OperandIndex index, std::unique_ptr< Operand > &&bwd_operand)
Add an operand for the backward pass to the graph with the given index and object.
If the given index is available, the call succeeds and bwd_operand is moved into the graph, which invalidates the caller's pointer. If the given index is already taken, the call fails and bwd_operand is not moved, so the caller's pointer remains valid.
Parameters
  [in] index: Index to be added
  [in] bwd_operand: Backward operand to be added
Returns
  index if successful, UNDEFINED otherwise
Definition at line 117 of file TrainableGraph.cc.
References onert::util::ObjectManager< Index, Object >::push().
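The "move only on success" contract works because the operand is taken as std::unique_ptr< Operand >&&: binding an rvalue reference moves nothing, so the callee decides whether to take ownership. A minimal sketch, assuming a hypothetical std::map-backed IndexedStore with int indices and std::nullopt in place of UNDEFINED (this is not the ObjectManager::push() implementation):

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <optional>
#include <utility>

// Hypothetical operand type for this sketch.
struct Operand
{
  int size;
};

// Hypothetical index-keyed store modelling the "move only on success" contract.
class IndexedStore
{
public:
  // Returns `index` on success and std::nullopt (standing in for UNDEFINED)
  // if the index is already taken. Binding `obj` as an rvalue reference does
  // not move it; try_emplace moves from it only when it actually inserts.
  std::optional<int> insertAt(int index, std::unique_ptr<Operand> &&obj)
  {
    const bool inserted = _objects.try_emplace(index, std::move(obj)).second;
    if (!inserted)
      return std::nullopt;
    return index;
  }

  bool exist(int index) const { return _objects.count(index) != 0; }

private:
  std::map<int, std::unique_ptr<Operand>> _objects;
};
```

std::map::try_emplace (C++17) fits this contract because, unlike emplace or insert, it does not move from its arguments when the key already exists. The same contract is documented for addOperand(OperandIndex, std::unique_ptr< Operand > &&).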
void onert::ir::train::TrainableGraph::addInput (const OperandIndex &ind, const std::string &name = "")
Definition at line 144 of file TrainableGraph.cc.
References onert::ir::Graph::addInput().
Referenced by onert::compiler::train::pass::LossInsertionPass::run().
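addInput() forwards to onert::ir::Graph::addInput() and accepts an optional name, while getInputIndex() maps a name back to an IOIndex. A minimal sketch of such a name-to-position mapping, assuming that a named input is identified by its position in the input list; InputRegistry and the integer index aliases are hypothetical, and std::nullopt stands in for an undefined IOIndex:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for onert's index types.
using OperandIndex = std::uint32_t;
using IOIndex = std::uint32_t;

class InputRegistry
{
public:
  // Appends an input; a non-empty name is recorded against its position.
  void addInput(OperandIndex ind, const std::string &name = "")
  {
    if (!name.empty())
      _name_to_input.emplace(name, static_cast<IOIndex>(_inputs.size()));
    _inputs.push_back(ind);
  }

  // Returns the position of a named input, or std::nullopt if unknown.
  std::optional<IOIndex> getInputIndex(const std::string &name) const
  {
    auto it = _name_to_input.find(name);
    if (it == _name_to_input.end())
      return std::nullopt;
    return it->second;
  }

  const std::vector<OperandIndex> &getInputs() const { return _inputs; }

private:
  std::vector<OperandIndex> _inputs;
  std::unordered_map<std::string, IOIndex> _name_to_input;
};
```

Unnamed inputs still occupy a position, so later named inputs keep indices that match their place in getInputs().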
void onert::ir::train::TrainableGraph::addLoss (const OperandIndex &loss_ind, const IOIndex &pred_io_ind)
Definition at line 369 of file TrainableGraph.cc.
Referenced by onert::compiler::train::pass::LossInsertionPass::run().
OperandIndex onert::ir::train::TrainableGraph::addOperand (const Shape &shape, const TypeInfo &type)
Definition at line 96 of file TrainableGraph.cc.
References onert::ir::Graph::addOperand().
Referenced by onert::compiler::train::pass::LossInsertionPass::run().
OperandIndex onert::ir::train::TrainableGraph::addOperand (OperandIndex index, std::unique_ptr< Operand > &&operand)
Add an operand to the graph with the given index and object.
If the given index is available, the call succeeds and operand is moved into the graph, which invalidates the caller's pointer. If the given index is already taken, the call fails and operand is not moved, so the caller's pointer remains valid.
Parameters
  [in] index: Index to be added
  [in] operand: Operand to be added
Returns
  index if successful, UNDEFINED otherwise
Definition at line 101 of file TrainableGraph.cc.
References onert::ir::Graph::addOperand().
OperationIndex onert::ir::train::TrainableGraph::addOperation (std::unique_ptr< ITrainableOperation > &&operation)
Add a new trainable operation to the graph.
If the given operation has at least one invalid operand index, the call fails and operation is not moved, so the caller's pointer remains valid.
Parameters
  operation: Operation to be added
Returns
  index if successful, UNDEFINED otherwise
Definition at line 106 of file TrainableGraph.cc.
References onert::ir::Graph::addOperation(), and operation().
Referenced by onert::compiler::train::pass::LossInsertionPass::run().
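To keep the "not moved on failure" guarantee, every operand index has to be checked before ownership is taken. A minimal sketch with hypothetical Operation and OpGraph types, where operand indices are plain ints and std::nullopt stands in for UNDEFINED:

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <set>
#include <utility>
#include <vector>

// Hypothetical operation that only records the operand indices it uses.
struct Operation
{
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Hypothetical graph: a set of known operand indices plus owned operations.
class OpGraph
{
public:
  void addOperand(int index) { _operands.insert(index); }

  // Checks every operand index first and takes ownership only afterwards,
  // so on failure `op` is left untouched and the caller still owns it.
  std::optional<int> addOperation(std::unique_ptr<Operation> &&op)
  {
    auto all_known = [this](const std::vector<int> &indices) {
      for (int ind : indices)
        if (_operands.count(ind) == 0)
          return false;
      return true;
    };
    if (!all_known(op->inputs) || !all_known(op->outputs))
      return std::nullopt; // std::nullopt stands in for UNDEFINED
    _ops.push_back(std::move(op));
    return static_cast<int>(_ops.size()) - 1;
  }

private:
  std::set<int> _operands;
  std::vector<std::unique_ptr<Operation>> _ops;
};
```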
void onert::ir::train::TrainableGraph::addOutput (const OperandIndex &ind, const std::string &name = "")
Definition at line 149 of file TrainableGraph.cc.
References onert::ir::Graph::addOutput().
const Operands & onert::ir::train::TrainableGraph::backward_operands ( ) const  [inline]
Definition at line 126 of file TrainableGraph.h.
std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::btopolSortOperations ( ) const
Definition at line 275 of file TrainableGraph.cc.
References onert::util::Set< Element >::add(), onert::util::ObjectManager< Index, Object >::at(), onert::util::Set< Element >::contains(), onert::ir::DUPLICATED, onert::ir::Operand::getDef(), onert::ir::IOperation::getInputs(), onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::IOperation::opcode(), operands(), operations(), onert::util::Set< Element >::remove(), and onert::ir::UNDEFINED.
Referenced by essentialBackwardOrder().
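btopolSortOperations() appears to produce the order in which operations are visited during the backward pass, where each operation comes before the operations that produce its inputs. Its references (Operand::getDef(), IOperation::getInputs(), and the Set add/contains/remove calls) hint at a traversal over operand producers. The sketch below shows one way to obtain such an order, a post-order DFS over producers that is then reversed; the Op type and backwardTopologicalOrder() are hypothetical, and the onert implementation may differ in detail.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <vector>

// Hypothetical operation: the operand indices it reads and writes.
struct Op
{
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Returns an order in which every operation appears before the operations
// that produce its inputs, i.e. the reverse of a forward topological order.
// Assumes the graph is acyclic.
std::vector<int> backwardTopologicalOrder(const std::map<int, Op> &ops)
{
  // Map each operand to the operation that defines it.
  std::map<int, int> def;
  for (const auto &[index, op] : ops)
    for (int out : op.outputs)
      def[out] = index;

  std::set<int> visited;
  std::vector<int> forward;
  // Post-order DFS over producers: an operation is emitted only after every
  // operation that defines one of its inputs.
  std::function<void(int)> visit = [&](int index) {
    if (!visited.insert(index).second)
      return;
    for (int in : ops.at(index).inputs)
    {
      auto it = def.find(in);
      if (it != def.end()) // graph inputs and constants have no producer
        visit(it->second);
    }
    forward.push_back(index);
  };
  for (const auto &entry : ops)
    visit(entry.first);

  std::reverse(forward.begin(), forward.end());
  return forward;
}
```

Reversing a forward topological order always yields a valid backward order, which is why the same traversal can serve both directions.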
void onert::ir::train::TrainableGraph::changeBackwardShape (const OperandIndex &ind, const ir::Shape &new_shape)
Definition at line 138 of file TrainableGraph.cc.
References onert::util::ObjectManager< Index, Object >::at(), onert::util::ObjectManager< Index, Object >::exist(), onert::ir::Operand::info(), and onert::ir::OperandInfo::shape().
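void onert::ir::train::TrainableGraph::changeShape (const OperandIndex &ind, const ir::Shape &new_shape)  [override], [virtual]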
Implements onert::ir::IGraph.
Definition at line 133 of file TrainableGraph.cc.
References onert::ir::Graph::changeShape().
void onert::ir::train::TrainableGraph::disableBackward (const OperationIndex &index)
Definition at line 188 of file TrainableGraph.cc.
References onert::util::ObjectManager< Index, Object >::at(), onert::ir::train::ITrainableOperation::disableBackward(), and onert::ir::Graph::operations().
void onert::ir::train::TrainableGraph::enableBackward (const OperationIndex &index)
Definition at line 181 of file TrainableGraph.cc.
References onert::util::ObjectManager< Index, Object >::at(), and onert::ir::Graph::operations().
std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::essentialBackwardOrder ( ) const
Definition at line 314 of file TrainableGraph.cc.
References btopolSortOperations(), onert::ir::train::ITrainableOperation::isRequiredForBackward(), operation(), and truncateBackwardOrder().
Referenced by onert::backend::train::TensorPlanner::planBackPropTensors(), onert::backend::train::TensorPlanner::planDisposableBackPropTensors(), onert::backend::train::TensorPlanner::planGradientTensors(), onert::backend::train::TensorPlanner::planLayerScopeTensors(), and onert::backend::train::TensorPlanner::planNonConstTensors().
IOIndex onert::ir::train::TrainableGraph::getInputIndex (const std::string &name) const  [override], [virtual]
Implements onert::ir::IGraph.
Definition at line 123 of file TrainableGraph.cc.
References onert::ir::Graph::getInputIndex().
const OperandIndexSequence & onert::ir::train::TrainableGraph::getInputs ( ) const  [inline], [override], [virtual]
Implements onert::ir::IGraph.
Definition at line 119 of file TrainableGraph.h.
References onert::ir::Graph::getInputs().
Referenced by onert::exec::train::TrainableExecutor::TrainableExecutor().
OperandIndex onert::ir::train::TrainableGraph::getLossIndex (const IOIndex &pred_io_ind) const
Definition at line 374 of file TrainableGraph.cc.
Referenced by onert::exec::train::TrainableExecutor::getLoss().
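addLoss() is called from LossInsertionPass::run() and getLossIndex() from TrainableExecutor::getLoss(), which suggests a table from each prediction output (an IOIndex) to the operand that holds its loss. A minimal sketch of such a table; LossTable and the integer aliases are hypothetical, std::nullopt stands in for an undefined OperandIndex, and how onert handles a second registration for the same output is not documented here:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <unordered_map>

// Hypothetical stand-ins for onert's index types.
using OperandIndex = std::uint32_t;
using IOIndex = std::uint32_t;

class LossTable
{
public:
  // Records that the loss for prediction output `pred_io_ind` lives in
  // operand `loss_ind`. Re-registering an output replaces the old entry here.
  void addLoss(OperandIndex loss_ind, IOIndex pred_io_ind) { _losses[pred_io_ind] = loss_ind; }

  // Returns the loss operand for a prediction output, or std::nullopt if
  // no loss was registered for it.
  std::optional<OperandIndex> getLossIndex(IOIndex pred_io_ind) const
  {
    auto it = _losses.find(pred_io_ind);
    if (it == _losses.end())
      return std::nullopt;
    return it->second;
  }

private:
  std::unordered_map<IOIndex, OperandIndex> _losses;
};
```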
IOIndex onert::ir::train::TrainableGraph::getOutputIndex (const std::string &name) const  [override], [virtual]
Implements onert::ir::IGraph.
Definition at line 128 of file TrainableGraph.cc.
References onert::ir::Graph::getOutputIndex().
const OperandIndexSequence & onert::ir::train::TrainableGraph::getOutputs ( ) const  [inline], [override], [virtual]
Implements onert::ir::IGraph.
Definition at line 120 of file TrainableGraph.h.
References onert::ir::Graph::getOutputs().
Referenced by onert::compiler::train::pass::LossInsertionPass::run(), and onert::exec::train::TrainableExecutor::TrainableExecutor().
const Graph & onert::ir::train::TrainableGraph::graph ( ) const  [inline]
Definition at line 128 of file TrainableGraph.h.
Referenced by TopologicalSortHelper.TopologicalSortHelper::add_edge(), onert::exec::train::TrainableExecutor::graph(), onert::compiler::train::LoweredTrainableGraph::graph(), onert::compiler::train::LoweredTrainableGraph::graph(), onert::ir::train::UseDefGenerator::operator()(), and TopologicalSortHelper.TopologicalSortHelper::sort_util().
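Operands & onert::ir::train::TrainableGraph::operands ( )  [inline]
const Operands & onert::ir::train::TrainableGraph::operands ( ) const  [inline], [override], [virtual]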
Implements onert::ir::IGraph.
Definition at line 123 of file TrainableGraph.h.
References onert::ir::Graph::operands().
Referenced by btopolSortOperations(), onert::backend::train::BackendContext::gen(), onert::backend::builtin::train::BackendContext::gen(), onert::backend::basic::train::genTensors(), onert::backend::train::TensorPlanner::planBackPropTensors(), onert::compiler::train::pass::LossInsertionPass::run(), truncateBackwardOrder(), onert::backend::train::KernelGenerator::visit(), onert::backend::train::KernelGenerator::visit(), onert::backend::train::KernelGenerator::visit(), and onert::ir::train::UseDefGenerator::visit().
const ITrainableOperation & onert::ir::train::TrainableGraph::operation (OperationIndex index) const
Definition at line 175 of file TrainableGraph.cc.
References onert::util::ObjectManager< Index, Object >::at(), and onert::ir::Graph::operations().
Referenced by addOperation(), essentialBackwardOrder(), onert::backend::builtin::train::KernelGenerator::generate(), onert::backend::train::KernelGenerator::generate(), onert::compiler::train::StaticBackwardShapeInferer::infer(), onert::backend::train::TensorPlanner::planDisposableBackPropTensors(), replaceOperation(), truncateBackwardOrder(), and onert::ir::train::UseDefGenerator::UseDefGenerator().
const Operations & onert::ir::train::TrainableGraph::operations ( ) const  [inline], [override], [virtual]
Implements onert::ir::IGraph.
Definition at line 125 of file TrainableGraph.h.
References onert::ir::Graph::operations().
Referenced by btopolSortOperations(), onert::backend::train::KernelGenerator::KernelGenerator(), onert::backend::train::TensorPlanner::planBackPropTensors(), onert::backend::train::TensorPlanner::planGradientTensors(), onert::backend::train::TensorPlanner::planNonConstTensors(), onert::compiler::train::pass::LossInsertionPass::run(), TrainableGraph(), truncateBackwardOrder(), and verify().
void onert::ir::train::TrainableGraph::removeOperand (const OperandIndex &ind)
Definition at line 173 of file TrainableGraph.cc.
References onert::ir::Graph::removeOperand().
OperationIndex onert::ir::train::TrainableGraph::replaceOperation (OperationIndex index, std::unique_ptr< ITrainableOperation > &&operation)
Replace a trainable operation that the graph already has.
If the graph already has an operation at the given index, the call succeeds and operation is moved into the graph, which invalidates the caller's pointer. If the given operation has at least one invalid operand index, the call fails and operation is not moved, so the caller's pointer remains valid.
No information in the graph is changed except for the replaced operation.
Parameters
  operation: Operation that replaces the existing one
Returns
  index if successful, UNDEFINED otherwise
Definition at line 111 of file TrainableGraph.cc.
References operation(), and onert::ir::Graph::replaceOperation().
Referenced by TrainableGraph().
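Unlike addOperation(), replaceOperation() requires the target index to hold an operation already, and only that slot changes. A minimal sketch with hypothetical OpTable and Operation types (int indices, std::nullopt in place of UNDEFINED); it models the documented contract, not the onert source:

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical operation: a name plus every operand index it touches.
struct Operation
{
  std::string name;
  std::vector<int> operands;
};

class OpTable
{
public:
  void addOperand(int index) { _operands.insert(index); }
  void put(int index, std::unique_ptr<Operation> op) { _ops[index] = std::move(op); }
  const Operation &at(int index) const { return *_ops.at(index); }

  // Swaps `op` in at an existing index. Fails, leaving `op` with the caller,
  // if the index holds no operation or `op` uses an unknown operand.
  // Nothing else in the table changes.
  std::optional<int> replaceOperation(int index, std::unique_ptr<Operation> &&op)
  {
    auto it = _ops.find(index);
    if (it == _ops.end())
      return std::nullopt;
    for (int ind : op->operands)
      if (_operands.count(ind) == 0)
        return std::nullopt;
    it->second = std::move(op); // the previous operation is destroyed here
    return index;
  }

private:
  std::set<int> _operands;
  std::map<int, std::unique_ptr<Operation>> _ops;
};
```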
void onert::ir::train::TrainableGraph::setInputs (OperandIndexSequence inputs, std::unordered_map< std::string, IOIndex > name_to_input)
void onert::ir::train::TrainableGraph::setOutputs (OperandIndexSequence outputs, std::unordered_map< std::string, IOIndex > name_to_output)
void onert::ir::train::TrainableGraph::setTrainingUseDefs (const UseDefChains &training_defuses)
Definition at line 194 of file TrainableGraph.cc.
Referenced by updateGraphDependency().
std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::topolSortOperations ( ) const
Definition at line 267 of file TrainableGraph.cc.
References onert::ir::Graph::topolSortOperations().
Referenced by onert::backend::train::TensorPlanner::planLayerScopeTensors(), onert::backend::train::TensorPlanner::planNonConstTensors(), and onert::ir::train::UseDefGenerator::UseDefGenerator().
const UseDefChains & onert::ir::train::TrainableGraph::trainingUseDefs ( ) const  [inline]
Definition at line 132 of file TrainableGraph.h.
Referenced by onert::backend::train::TensorPlanner::planBackPropTensors(), onert::backend::train::TensorPlanner::planGradientTensors(), onert::backend::train::TensorPlanner::planNonConstTensors(), and onert::backend::train::TensorPlanner::planTrainableTensors().
std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::truncateBackwardOrder (const std::vector< ir::OperationIndex > &backward_order) const
Truncate the backward order of operations, using as the alive condition whether each operation has trainable parameters.
Parameters
  backward_order: The order of operations in a backward graph
Definition at line 360 of file TrainableGraph.cc.
References operation(), and truncateBackwardOrder().
Referenced by essentialBackwardOrder(), and truncateBackwardOrder().
std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::truncateBackwardOrder (std::vector< ir::OperationIndex > backward_order, std::function< bool(const ir::OperationIndex &)> truncating_cond) const
Truncate the backward order of operations in accordance with the given alive condition.
Parameters
  backward_order: The order of operations in a backward graph
  truncating_cond: The alive condition used to truncate the backward order
Definition at line 325 of file TrainableGraph.cc.
References onert::util::ObjectManager< Index, Object >::at(), onert::ir::IOperation::getOutputs(), operands(), and operations().
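The documentation does not spell out how the alive condition is applied. One plausible reading, sketched below as an assumption, is that an operation survives truncation if it satisfies the condition or lies downstream (in the forward direction) of an operation that does, because gradients must pass through those operations to reach trainable parameters. The Successors map and truncateOrder() are hypothetical stand-ins for the graph's operand use lists and the member function.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <vector>

// Hypothetical graph description: for each operation, the operations that
// consume its outputs (its forward successors).
using Successors = std::map<int, std::vector<int>>;

// Keeps the operations that satisfy `alive_cond` or are downstream of one
// that does; operations upstream of every alive operation are dropped.
std::vector<int> truncateOrder(std::vector<int> backward_order, const Successors &succ,
                               const std::function<bool(int)> &alive_cond)
{
  // Walk in forward order so aliveness flows from producers to consumers.
  std::vector<int> forward_order(backward_order.rbegin(), backward_order.rend());
  std::set<int> alive;
  for (int index : forward_order)
  {
    if (alive_cond(index))
      alive.insert(index);
    if (alive.count(index) == 0)
      continue;
    auto it = succ.find(index);
    if (it != succ.end())
      alive.insert(it->second.begin(), it->second.end());
  }

  // Drop dead operations while preserving the backward order.
  backward_order.erase(std::remove_if(backward_order.begin(), backward_order.end(),
                                      [&](int index) { return alive.count(index) == 0; }),
                       backward_order.end());
  return backward_order;
}
```

Under the same assumption, the single-argument overload would pass "has trainable parameters" as the condition, and essentialBackwardOrder(), whose references include btopolSortOperations(), ITrainableOperation::isRequiredForBackward(), and truncateBackwardOrder(), would apply an isRequiredForBackward() check to the backward topological order.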
void onert::ir::train::TrainableGraph::updateGraphDependency ( )
Definition at line 380 of file TrainableGraph.cc.
References setTrainingUseDefs(), and onert::ir::Graph::verify().
void onert::ir::train::TrainableGraph::verify ( ) const
Definition at line 154 of file TrainableGraph.cc.
References onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::IOperation::name(), operations(), and onert::ir::Graph::verify().