ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::ir::train::TrainableGraph Class Reference

#include <TrainableGraph.h>

Collaboration diagram for onert::ir::train::TrainableGraph:

Public Member Functions

 TrainableGraph ()
 Construct a new Trainable Graph object.
 
 TrainableGraph (const TrainableGraph &tgraph)
 
 TrainableGraph (const Graph &graph)
 
 ~TrainableGraph ()=default
 
OperandIndex addOperand (const Shape &shape, const TypeInfo &type)
 
OperandIndex addOperand (OperandIndex index, std::unique_ptr< Operand > &&operand)
 Add an operand to the graph with the given index and object.
 
OperationIndex addOperation (std::unique_ptr< ITrainableOperation > &&operation)
 Add a new trainable operation to the graph.
 
OperationIndex replaceOperation (OperationIndex index, std::unique_ptr< ITrainableOperation > &&operation)
 Replace a trainable operation which the graph already has.
 
OperandIndex addBackwardOperand (OperandIndex index, std::unique_ptr< Operand > &&bwd_operand)
 Add an operand for backwarding to the graph with the given index and object.
 
void changeShape (const OperandIndex &ind, const ir::Shape &new_shape) override
 
void changeBackwardShape (const OperandIndex &ind, const ir::Shape &new_shape)
 
void addInput (const OperandIndex &ind, const std::string &name="")
 
void addOutput (const OperandIndex &ind, const std::string &name="")
 
void addLoss (const OperandIndex &loss_ind, const IOIndex &pred_io_ind)
 
void verify () const
 
void removeOperand (const OperandIndex &ind)
 
void setInputs (OperandIndexSequence inputs, std::unordered_map< std::string, IOIndex > name_to_input)
 
void setOutputs (OperandIndexSequence outputs, std::unordered_map< std::string, IOIndex > name_to_output)
 
void enableBackward (const OperationIndex &index)
 
void disableBackward (const OperationIndex &index)
 
void setTrainingUseDefs (const UseDefChains &training_defuses)
 
const OperandIndexSequence & getInputs () const override
 
const OperandIndexSequence & getOutputs () const override
 
IOIndex getInputIndex (const std::string &name) const override
 
IOIndex getOutputIndex (const std::string &name) const override
 
const Operands & operands () const override
 
Operands & operands ()
 
const Operations & operations () const override
 
const Operands & backward_operands () const
 
OperandIndex getLossIndex (const IOIndex &pred_io_ind) const
 
const Graph & graph () const
 
const ITrainableOperation & operation (OperationIndex index) const
 
const UseDefChains & trainingUseDefs () const
 
std::vector< ir::OperationIndex > topolSortOperations () const
 
std::vector< ir::OperationIndex > btopolSortOperations () const
 
std::vector< ir::OperationIndex > essentialBackwardOrder () const
 
std::vector< ir::OperationIndex > truncateBackwardOrder (const std::vector< ir::OperationIndex > &backward_order) const
 Truncate the backward order of operations in accordance with the alive condition whether the corresponding operation has trainable parameters.
 
std::vector< ir::OperationIndex > truncateBackwardOrder (std::vector< ir::OperationIndex > backward_order, std::function< bool(const ir::OperationIndex &)> truncating_cond) const
 Truncate the backward order of operations in accordance with the given alive condition.
 
void updateGraphDependency ()
 
- Public Member Functions inherited from onert::ir::IGraph
virtual ~IGraph ()=default
 

Detailed Description

Definition at line 31 of file TrainableGraph.h.

Constructor & Destructor Documentation

◆ TrainableGraph() [1/3]

onert::ir::train::TrainableGraph::TrainableGraph ( )
explicit

Construct a new Trainable Graph object.

Parameters
graph

Definition at line 78 of file TrainableGraph.cc.

78: _graph{} {}

◆ TrainableGraph() [2/3]

onert::ir::train::TrainableGraph::TrainableGraph ( const TrainableGraph & tgraph )
explicit

Definition at line 80 of file TrainableGraph.cc.

81 : _graph{tgraph._graph}, _backward_operands{tgraph._backward_operands},
82 _training_defuses{tgraph._training_defuses}, _losses{tgraph._losses}
83{
84 tgraph.operations().iterate(
85 [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
86 replaceOperation(index, dynamic_cast<const ITrainableOperation &>(op).clone());
87 });
88}
OperationIndex replaceOperation(OperationIndex index, std::unique_ptr< ITrainableOperation > &&operation)
Replace a trainable operation which the graph already has.
const Operations & operations() const override
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
std::unique_ptr< Operation > clone(const IOperation &operation)

References onert::ir::clone(), onert::util::ObjectManager< Index, Object >::iterate(), operations(), and replaceOperation().

◆ TrainableGraph() [3/3]

onert::ir::train::TrainableGraph::TrainableGraph ( const Graph & graph )
explicit

Definition at line 90 of file TrainableGraph.cc.

90: _graph{graph} {}

◆ ~TrainableGraph()

onert::ir::train::TrainableGraph::~TrainableGraph ( )
default

Member Function Documentation

◆ addBackwardOperand()

OperandIndex onert::ir::train::TrainableGraph::addBackwardOperand ( OperandIndex  index,
std::unique_ptr< Operand > &&  bwd_operand 
)

Add an operand for backwarding to the graph with the given index and object.

If the given index is available, the call succeeds and bwd_operand is moved from, which invalidates the caller's pointer. If the given index is already taken, the call fails and bwd_operand is not moved, so the caller's pointer remains valid.

Parameters
[in] index: Index to be added
[in] bwd_operand: Backward operand to be added
Returns
OperandIndex index if successful, UNDEFINED otherwise

Definition at line 113 of file TrainableGraph.cc.

115{
116 return _backward_operands.push(std::move(bwd_operand), index);
117}
Index push(std::unique_ptr< Object > &&object, Index index)
Put the object in the container with given index.

References onert::util::ObjectManager< Index, Object >::push().

◆ addInput()

void onert::ir::train::TrainableGraph::addInput ( const OperandIndex & ind,
const std::string &  name = "" 
)

Definition at line 140 of file TrainableGraph.cc.

141{
142 _graph.addInput(ind, name);
143}
void addInput(const OperandIndex &ind, const std::string &name="")
Definition Graph.cc:121

References onert::ir::Graph::addInput().

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addLoss()

void onert::ir::train::TrainableGraph::addLoss ( const OperandIndex & loss_ind,
const IOIndex & pred_io_ind
)

Definition at line 365 of file TrainableGraph.cc.

366{
367 _losses.emplace(pred_ioind, loss_ind);
368}

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addOperand() [1/2]

OperandIndex onert::ir::train::TrainableGraph::addOperand ( const Shape & shape,
const TypeInfo & type
)

Definition at line 92 of file TrainableGraph.cc.

93{
94 return _graph.addOperand(shape, type);
95}
OperandIndex addOperand(const Shape &shape, const TypeInfo &type)
Definition Graph.cc:33

References onert::ir::Graph::addOperand().

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addOperand() [2/2]

OperandIndex onert::ir::train::TrainableGraph::addOperand ( OperandIndex  index,
std::unique_ptr< Operand > &&  operand 
)

Add an operand to the graph with the given index and object.

If the given index is available, the call succeeds and operand is moved from, which invalidates the caller's pointer. If the given index is already taken, the call fails and operand is not moved, so the caller's pointer remains valid.

Parameters
[in] index: Index to be added
[in] operand: Operand to be added
Returns
OperandIndex index if successful, UNDEFINED otherwise

Definition at line 97 of file TrainableGraph.cc.

98{
99 return _graph.addOperand(index, std::move(operand));
100}

References onert::ir::Graph::addOperand().

◆ addOperation()

OperationIndex onert::ir::train::TrainableGraph::addOperation ( std::unique_ptr< ITrainableOperation > &&  operation)

Add a new trainable operation to the graph.

If the given operation has at least one invalid operand index, the call fails and operation is not moved, so the caller's pointer remains valid.

Parameters
operation: Operation to be added
Returns
OperationIndex index if successful, UNDEFINED otherwise

Definition at line 102 of file TrainableGraph.cc.

103{
104 return _graph.addOperation(std::move(operation));
105}
OperationIndex addOperation(std::unique_ptr< IOperation > &&node)
Definition Graph.cc:67

References onert::ir::Graph::addOperation(), and operation().

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addOutput()

void onert::ir::train::TrainableGraph::addOutput ( const OperandIndex & ind,
const std::string &  name = "" 
)

Definition at line 145 of file TrainableGraph.cc.

146{
147 _graph.addOutput(ind, name);
148}
void addOutput(const OperandIndex &ind, const std::string &name="")
Definition Graph.cc:128

References onert::ir::Graph::addOutput().

◆ backward_operands()

const Operands & onert::ir::train::TrainableGraph::backward_operands ( ) const
inline

Definition at line 122 of file TrainableGraph.h.

122{ return _backward_operands; }

◆ btopolSortOperations()

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::btopolSortOperations ( ) const

Definition at line 271 of file TrainableGraph.cc.

272{
273 std::vector<ir::OperationIndex> ret;
274 util::Set<ir::OperationIndex> unvisited;
275 ir::OperationIndex loss_idx;
276 operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
277 unvisited.add(index);
278 if (op.opcode() == ir::OpCode::Loss)
279 {
280 assert(!loss_idx.valid()); // Should be only one loss
281 loss_idx = index;
282 }
283 });
284
285 std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
286 [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
287 if (!unvisited.contains(index))
288 return;
289 unvisited.remove(index);
290
291 for (const auto &input : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
292 {
293 const auto &operand = operands().at(input);
294 const auto &def = operand.getDef();
295 if (!def.valid())
296 continue;
297 dfs(def, operations().at(def));
298 }
299
300 ret.push_back(index);
301 };
302
303 dfs(loss_idx, operations().at(loss_idx));
304 std::reverse(ret.begin(), ret.end());
305 validateBackwardTopologicalOrder(ret);
306
307 return ret;
308}
OperationIndex getDef() const
Definition Operand.h:51
const OperandIndexSequence & getInputs() const override
const Operands & operands() const override
const Object & at(const Index &index) const
Get the object that is associated with the given index.
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
virtual OpCode opcode() const =0

References onert::util::Set< Element >::add(), onert::util::ObjectManager< Index, Object >::at(), onert::util::Set< Element >::contains(), onert::ir::DUPLICATED, onert::ir::Operand::getDef(), onert::ir::IOperation::getInputs(), onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::IOperation::opcode(), operands(), operations(), onert::util::Set< Element >::remove(), and onert::ir::UNDEFINED.

Referenced by essentialBackwardOrder().

◆ changeBackwardShape()

void onert::ir::train::TrainableGraph::changeBackwardShape ( const OperandIndex & ind,
const ir::Shape & new_shape
)

Definition at line 134 of file TrainableGraph.cc.

135{
136 assert(_backward_operands.exist(index));
137 _backward_operands.at(index).info().shape(new_shape);
138}
const OperandInfo & info(void) const
Definition Operand.h:46
const Shape & shape() const
Return tensor shape.
Definition OperandInfo.h:93
bool exist(const Index &index) const
Get the object that is associated with the given index.

References onert::util::ObjectManager< Index, Object >::at(), onert::util::ObjectManager< Index, Object >::exist(), onert::ir::Operand::info(), and onert::ir::OperandInfo::shape().

◆ changeShape()

void onert::ir::train::TrainableGraph::changeShape ( const OperandIndex & ind,
const ir::Shape & new_shape
)
override virtual

Implements onert::ir::IGraph.

Definition at line 129 of file TrainableGraph.cc.

130{
131 _graph.changeShape(index, new_shape);
132}
void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override
Definition Graph.cc:115

References onert::ir::Graph::changeShape().

◆ disableBackward()

void onert::ir::train::TrainableGraph::disableBackward ( const OperationIndex & index )

Definition at line 184 of file TrainableGraph.cc.

185{
186 auto &op = dynamic_cast<ir::train::ITrainableOperation &>(_graph.operations().at(index));
187 op.disableBackward();
188}
const Operations & operations() const override
Definition Graph.h:112

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::train::ITrainableOperation::disableBackward(), and onert::ir::Graph::operations().

◆ enableBackward()

void onert::ir::train::TrainableGraph::enableBackward ( const OperationIndex & index )

Definition at line 177 of file TrainableGraph.cc.

178{
179 auto op = dynamic_cast<ir::train::ITrainableOperation *>(&_graph.operations().at(index));
180 assert(op);
181 op->enableBackward();
182}

References onert::util::ObjectManager< Index, Object >::at(), and onert::ir::Graph::operations().

◆ essentialBackwardOrder()

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::essentialBackwardOrder ( ) const

Definition at line 310 of file TrainableGraph.cc.

311{
312 auto backward_order = btopolSortOperations();
313 // get rid of all nodes not reachable from a node with trainable parameters
314 backward_order = truncateBackwardOrder(backward_order, [&](const OperationIndex &index) {
315 return operation(index).isRequiredForBackward();
316 });
317
318 return truncateBackwardOrder(backward_order);
319}
virtual bool isRequiredForBackward() const =0
const ITrainableOperation & operation(OperationIndex index) const
std::vector< ir::OperationIndex > btopolSortOperations() const
std::vector< ir::OperationIndex > truncateBackwardOrder(const std::vector< ir::OperationIndex > &backward_order) const
Truncate the backward order of operations in accordance with the alive condition whether the correspo...

References btopolSortOperations(), onert::ir::train::ITrainableOperation::isRequiredForBackward(), operation(), and truncateBackwardOrder().

Referenced by onert::backend::train::TensorPlanner::planBackPropTensors(), onert::backend::train::TensorPlanner::planDisposableBackPropTensors(), onert::backend::train::TensorPlanner::planGradientTensors(), onert::backend::train::TensorPlanner::planLayerScopeTensors(), and onert::backend::train::TensorPlanner::planNonConstTensors().

◆ getInputIndex()

IOIndex onert::ir::train::TrainableGraph::getInputIndex ( const std::string &  name) const
override virtual

Implements onert::ir::IGraph.

Definition at line 119 of file TrainableGraph.cc.

120{
121 return _graph.getInputIndex(name);
122}
IOIndex getInputIndex(const std::string &name) const override
Definition Graph.cc:135

References onert::ir::Graph::getInputIndex().

◆ getInputs()

const OperandIndexSequence & onert::ir::train::TrainableGraph::getInputs ( ) const
inline override virtual

Implements onert::ir::IGraph.

Definition at line 115 of file TrainableGraph.h.

115{ return _graph.getInputs(); }
const OperandIndexSequence & getInputs() const override
Definition Graph.h:104

References onert::ir::Graph::getInputs().

Referenced by onert::exec::train::TrainableExecutor::TrainableExecutor().

◆ getLossIndex()

OperandIndex onert::ir::train::TrainableGraph::getLossIndex ( const IOIndex & pred_io_ind ) const

Definition at line 370 of file TrainableGraph.cc.

371{
372 auto itr = _losses.find(pred_ioind);
373 return (itr == _losses.end()) ? OperandIndex{} : itr->second;
374}

Referenced by onert::exec::train::TrainableExecutor::getLoss().

◆ getOutputIndex()

IOIndex onert::ir::train::TrainableGraph::getOutputIndex ( const std::string &  name) const
override virtual

Implements onert::ir::IGraph.

Definition at line 124 of file TrainableGraph.cc.

125{
126 return _graph.getOutputIndex(name);
127}
IOIndex getOutputIndex(const std::string &name) const override
Definition Graph.cc:141

References onert::ir::Graph::getOutputIndex().

◆ getOutputs()

const OperandIndexSequence & onert::ir::train::TrainableGraph::getOutputs ( ) const
inline override virtual

Implements onert::ir::IGraph.

Definition at line 116 of file TrainableGraph.h.

116{ return _graph.getOutputs(); }
const OperandIndexSequence & getOutputs() const override
Definition Graph.h:106

References onert::ir::Graph::getOutputs().

Referenced by onert::compiler::train::pass::LossInsertionPass::run(), and onert::exec::train::TrainableExecutor::TrainableExecutor().

◆ graph()

◆ operands() [1/2]

Operands & onert::ir::train::TrainableGraph::operands ( )
inline

Definition at line 120 of file TrainableGraph.h.

120{ return _graph.operands(); } // TODO Remove this non-const accessor
const Operands & operands() const override
Definition Graph.h:110

References onert::ir::Graph::operands().

◆ operands() [2/2]

◆ operation()

◆ operations()

◆ removeOperand()

void onert::ir::train::TrainableGraph::removeOperand ( const OperandIndex & ind )

Definition at line 169 of file TrainableGraph.cc.

169{ _graph.removeOperand(ind); }
void removeOperand(const OperandIndex &ind)
Definition Graph.h:92

References onert::ir::Graph::removeOperand().

◆ replaceOperation()

OperationIndex onert::ir::train::TrainableGraph::replaceOperation ( OperationIndex  index,
std::unique_ptr< ITrainableOperation > &&  operation 
)

Replace a trainable operation which the graph already has.

If the given index is available, the call succeeds and operation is moved from, which invalidates the caller's pointer. If the given operation has at least one invalid operand index, the call fails and operation is not moved, so the caller's pointer remains valid.

No information in the graph is changed except for replacing an operation.

Parameters
operation: Operation to be added
Returns
OperationIndex index if successful, UNDEFINED otherwise

Definition at line 107 of file TrainableGraph.cc.

109{
110 return _graph.replaceOperation(index, std::move(operation));
111}
OperationIndex replaceOperation(OperationIndex index, std::unique_ptr< IOperation > &&operation)
Replace an operation which the graph already has.
Definition Graph.cc:92

References operation(), and onert::ir::Graph::replaceOperation().

Referenced by TrainableGraph().

◆ setInputs()

void onert::ir::train::TrainableGraph::setInputs ( OperandIndexSequence  inputs,
std::unordered_map< std::string, IOIndex > name_to_input
)

◆ setOutputs()

void onert::ir::train::TrainableGraph::setOutputs ( OperandIndexSequence  outputs,
std::unordered_map< std::string, IOIndex > name_to_output
)

◆ setTrainingUseDefs()

void onert::ir::train::TrainableGraph::setTrainingUseDefs ( const UseDefChains & training_defuses )

Definition at line 190 of file TrainableGraph.cc.

191{
192 _training_defuses.clear();
193 // TODO Replace this loop with `std::unordered_map::insert_range` since C++23
194 for (const auto &[training_index, usedef] : training_defuses)
195 {
196 _training_defuses.emplace(training_index, usedef);
197 }
198}

Referenced by updateGraphDependency().

◆ topolSortOperations()

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::topolSortOperations ( ) const

Definition at line 263 of file TrainableGraph.cc.

264{
265 auto ret = _graph.topolSortOperations();
266 validateForwardTopologicalOrder(ret);
267
268 return ret;
269}
std::vector< ir::OperationIndex > topolSortOperations() const
Definition Graph.cc:182

References onert::ir::Graph::topolSortOperations().

Referenced by onert::backend::train::TensorPlanner::planLayerScopeTensors(), onert::backend::train::TensorPlanner::planNonConstTensors(), and onert::ir::train::UseDefGenerator::UseDefGenerator().

◆ trainingUseDefs()

◆ truncateBackwardOrder() [1/2]

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::truncateBackwardOrder ( const std::vector< ir::OperationIndex > &  backward_order) const

Truncate the backward order of operations in accordance with the alive condition whether the corresponding operation has trainable parameters.

Parameters
backward_order: The order of operations in a backward graph

Definition at line 356 of file TrainableGraph.cc.

357{
358 return truncateBackwardOrder(backward_order, [&](const ir::OperationIndex &index) {
359 const auto &trainable_op = operation(index);
360
361 return trainable_op.hasTrainableParameter();
362 });
363}

References operation(), and truncateBackwardOrder().

Referenced by essentialBackwardOrder(), and truncateBackwardOrder().

◆ truncateBackwardOrder() [2/2]

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::truncateBackwardOrder ( std::vector< ir::OperationIndex > backward_order,
std::function< bool(const ir::OperationIndex &)>  truncating_cond 
) const

Truncate the backward order of operations in accordance with the given alive condition.

Parameters
backward_order: The order of operations in a backward graph
alive_cond: The alive condition to stop the backward order

Definition at line 321 of file TrainableGraph.cc.

324{
325 auto forward_order = backward_order;
326 std::reverse(forward_order.begin(), forward_order.end());
327 std::set<ir::OperationIndex> alive;
328
329 for (const auto &index : forward_order)
330 {
331 if (alive_cond(index))
332 alive.insert(index);
333
334 // TODO: replace this with `std::set::contains` after C++20
335 if (alive.find(index) != alive.end())
336 {
337 const auto &op = operations().at(index);
338 for (const auto &output : op.getOutputs())
339 {
340 const auto &operand = operands().at(output);
341 for (const auto &use : operand.getUses())
342 alive.insert(use);
343 }
344 }
345 }
346
347 // TODO: replace this with `std::erase_if(std::vector)` after C++20
348 backward_order.erase(
349 std::remove_if(backward_order.begin(), backward_order.end(),
350 [&](const auto &index) { return alive.find(index) == alive.end(); }),
351 backward_order.end());
352 return backward_order;
353}
const OperandIndexSequence & getOutputs() const override

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::IOperation::getOutputs(), operands(), and operations().

◆ updateGraphDependency()

void onert::ir::train::TrainableGraph::updateGraphDependency ( )

Definition at line 376 of file TrainableGraph.cc.

377{
378 _graph.verify();
379
380 // Initialize training usedefs
382
383 disableUnusedBackwardNodes(_training_defuses, *this);
384
385 verifyTrainingUseDefs();
386}
void verify(void) const
Definition Graph.cc:147
void setTrainingUseDefs(const UseDefChains &training_defuses)

References setTrainingUseDefs(), and onert::ir::Graph::verify().

◆ verify()

void onert::ir::train::TrainableGraph::verify ( void  ) const

Definition at line 150 of file TrainableGraph.cc.

151{
152 _graph.verify();
153
155 try
156 {
157 [[maybe_unused]] const auto &casted_op =
158 dynamic_cast<const onert::ir::train::ITrainableOperation &>(op);
159 }
160 catch (const std::bad_cast &)
161 {
162 throw std::runtime_error("TrainableGraph: " + op.name() + " is not a trainable operation");
163 }
164 });
165
166 verifyTrainingUseDefs();
167}
virtual std::string name() const
Definition IOperation.h:36

References onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::IOperation::name(), operations(), and onert::ir::Graph::verify().


The documentation for this class was generated from the following files: