ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::ir::train::TrainableGraph Class Reference

#include <TrainableGraph.h>

Collaboration diagram for onert::ir::train::TrainableGraph:

Public Member Functions

 TrainableGraph ()
 Construct a new Trainable Graph object.
 
 TrainableGraph (const TrainableGraph &tgraph)
 
 TrainableGraph (const Graph &graph)
 
 ~TrainableGraph ()=default
 
OperandIndex addOperand (const Shape &shape, const TypeInfo &type)
 
OperandIndex addOperand (OperandIndex index, std::unique_ptr< Operand > &&operand)
 Add an operand to the graph with the given index and object.
 
OperationIndex addOperation (std::unique_ptr< ITrainableOperation > &&operation)
 Add a new trainable operation to the graph.
 
OperationIndex replaceOperation (OperationIndex index, std::unique_ptr< ITrainableOperation > &&operation)
 Replace a trainable operation which the graph already has.
 
OperandIndex addBackwardOperand (OperandIndex index, std::unique_ptr< Operand > &&bwd_operand)
 Add an operand for backwarding to the graph with the given index and object.
 
void changeShape (const OperandIndex &ind, const ir::Shape &new_shape) override
 
void changeBackwardShape (const OperandIndex &ind, const ir::Shape &new_shape)
 
void addInput (const OperandIndex &ind, const std::string &name="")
 
void addOutput (const OperandIndex &ind, const std::string &name="")
 
void addLoss (const OperandIndex &loss_ind, const IOIndex &pred_io_ind)
 
void verify () const
 
void removeOperand (const OperandIndex &ind)
 
void setInputs (OperandIndexSequence inputs, std::unordered_map< std::string, IOIndex > name_to_input)
 
void setOutputs (OperandIndexSequence outputs, std::unordered_map< std::string, IOIndex > name_to_output)
 
void enableBackward (const OperationIndex &index)
 
void disableBackward (const OperationIndex &index)
 
void setTrainingUseDefs (const UseDefChains &training_defuses)
 
const OperandIndexSequence & getInputs () const override
 
const OperandIndexSequence & getOutputs () const override
 
IOIndex getInputIndex (const std::string &name) const override
 
IOIndex getOutputIndex (const std::string &name) const override
 
const Operands & operands () const override
 
Operands & operands ()
 
const Operations & operations () const override
 
const Operands & backward_operands () const
 
OperandIndex getLossIndex (const IOIndex &pred_io_ind) const
 
const Graph & graph () const
 
const ITrainableOperation & operation (OperationIndex index) const
 
const UseDefChains & trainingUseDefs () const
 
std::vector< ir::OperationIndex > topolSortOperations () const
 
std::vector< ir::OperationIndex > btopolSortOperations () const
 
std::vector< ir::OperationIndex > essentialBackwardOrder () const
 
std::vector< ir::OperationIndex > truncateBackwardOrder (const std::vector< ir::OperationIndex > &backward_order) const
 Truncate the backward order of operations, using as the alive condition whether the corresponding operation has trainable parameters.
 
std::vector< ir::OperationIndex > truncateBackwardOrder (std::vector< ir::OperationIndex > backward_order, std::function< bool(const ir::OperationIndex &)> truncating_cond) const
 Truncate the backward order of operations in accordance with the given alive condition.
 
void updateGraphDependency ()
 
- Public Member Functions inherited from onert::ir::IGraph
virtual ~IGraph ()=default
 

Detailed Description

Definition at line 35 of file TrainableGraph.h.

Constructor & Destructor Documentation

◆ TrainableGraph() [1/3]

onert::ir::train::TrainableGraph::TrainableGraph ( )
explicit

Construct a new Trainable Graph object.

Parameters
graph

Definition at line 82 of file TrainableGraph.cc.

82: _graph{} {}

◆ TrainableGraph() [2/3]

onert::ir::train::TrainableGraph::TrainableGraph ( const TrainableGraph & tgraph)
explicit

Definition at line 84 of file TrainableGraph.cc.

85 : _graph{tgraph._graph}, _backward_operands{tgraph._backward_operands},
86 _training_defuses{tgraph._training_defuses}, _losses{tgraph._losses}
87{
88 tgraph.operations().iterate(
89 [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
90 replaceOperation(index, dynamic_cast<const ITrainableOperation &>(op).clone());
91 });
92}
OperationIndex replaceOperation(OperationIndex index, std::unique_ptr< ITrainableOperation > &&operation)
Replace a trainable operation which the graph already has.
const Operations & operations() const override
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
std::unique_ptr< Operation > clone(const IOperation &operation)

References onert::ir::clone(), onert::util::ObjectManager< Index, Object >::iterate(), operations(), and replaceOperation().

◆ TrainableGraph() [3/3]

onert::ir::train::TrainableGraph::TrainableGraph ( const Graph & graph)
explicit

Definition at line 94 of file TrainableGraph.cc.

94: _graph{graph} {}

◆ ~TrainableGraph()

onert::ir::train::TrainableGraph::~TrainableGraph ( )
default

Member Function Documentation

◆ addBackwardOperand()

OperandIndex onert::ir::train::TrainableGraph::addBackwardOperand ( OperandIndex  index,
std::unique_ptr< Operand > &&  bwd_operand 
)

Add an operand for backwarding to the graph with the given index and object.

If the given index is available, the call succeeds and bwd_operand is moved, which invalidates the caller's pointer. If the given index is already taken, the call fails and bwd_operand is not moved, so the caller's pointer remains valid.

Parameters
[in] index Index to be added
[in] bwd_operand Backward operand to be added
Returns
OperandIndex index if successful, UNDEFINED otherwise

Definition at line 117 of file TrainableGraph.cc.

119{
120 return _backward_operands.push(std::move(bwd_operand), index);
121}
Index push(std::unique_ptr< Object > &&object, Index index)
Put the object in the container with given index.

References onert::util::ObjectManager< Index, Object >::push().

◆ addInput()

void onert::ir::train::TrainableGraph::addInput ( const OperandIndex & ind,
const std::string &  name = "" 
)

Definition at line 144 of file TrainableGraph.cc.

145{
146 _graph.addInput(ind, name);
147}
void addInput(const OperandIndex &ind, const std::string &name="")
Definition Graph.cc:123

References onert::ir::Graph::addInput().

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addLoss()

void onert::ir::train::TrainableGraph::addLoss ( const OperandIndex & loss_ind,
const IOIndex & pred_io_ind 
)

Definition at line 369 of file TrainableGraph.cc.

370{
371 _losses.emplace(pred_ioind, loss_ind);
372}

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addOperand() [1/2]

OperandIndex onert::ir::train::TrainableGraph::addOperand ( const Shape & shape,
const TypeInfo & type 
)

Definition at line 96 of file TrainableGraph.cc.

97{
98 return _graph.addOperand(shape, type);
99}
OperandIndex addOperand(const Shape &shape, const TypeInfo &type)
Definition Graph.cc:35

References onert::ir::Graph::addOperand().

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addOperand() [2/2]

OperandIndex onert::ir::train::TrainableGraph::addOperand ( OperandIndex  index,
std::unique_ptr< Operand > &&  operand 
)

Add an operand to the graph with the given index and object.

If the given index is available, the call succeeds and operand is moved, which invalidates the caller's pointer. If the given index is already taken, the call fails and operand is not moved, so the caller's pointer remains valid.

Parameters
[in] index Index to be added
[in] operand Operand to be added
Returns
OperandIndex index if successful, UNDEFINED otherwise

Definition at line 101 of file TrainableGraph.cc.

102{
103 return _graph.addOperand(index, std::move(operand));
104}

References onert::ir::Graph::addOperand().

◆ addOperation()

OperationIndex onert::ir::train::TrainableGraph::addOperation ( std::unique_ptr< ITrainableOperation > &&  operation)

Add a new trainable operation to the graph.

If the given operation has at least one invalid operand index, the call fails and operation is not moved, so the caller's pointer remains valid.

Parameters
operation Operation to be added
Returns
OperationIndex index if successful, UNDEFINED otherwise

Definition at line 106 of file TrainableGraph.cc.

107{
108 return _graph.addOperation(std::move(operation));
109}
OperationIndex addOperation(std::unique_ptr< IOperation > &&node)
Definition Graph.cc:69

References onert::ir::Graph::addOperation(), and operation().

Referenced by onert::compiler::train::pass::LossInsertionPass::run().

◆ addOutput()

void onert::ir::train::TrainableGraph::addOutput ( const OperandIndex & ind,
const std::string &  name = "" 
)

Definition at line 149 of file TrainableGraph.cc.

150{
151 _graph.addOutput(ind, name);
152}
void addOutput(const OperandIndex &ind, const std::string &name="")
Definition Graph.cc:130

References onert::ir::Graph::addOutput().

◆ backward_operands()

const Operands & onert::ir::train::TrainableGraph::backward_operands ( ) const
inline

Definition at line 126 of file TrainableGraph.h.

126{ return _backward_operands; }

◆ btopolSortOperations()

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::btopolSortOperations ( ) const

Definition at line 275 of file TrainableGraph.cc.

276{
277 std::vector<ir::OperationIndex> ret;
278 util::Set<ir::OperationIndex> unvisited;
279 ir::OperationIndex loss_idx;
280 operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
281 unvisited.add(index);
282 if (op.opcode() == ir::OpCode::Loss)
283 {
284 assert(!loss_idx.valid()); // Should be only one loss
285 loss_idx = index;
286 }
287 });
288
289 std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
290 [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
291 if (!unvisited.contains(index))
292 return;
293 unvisited.remove(index);
294
295 for (const auto &input : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
296 {
297 const auto &operand = operands().at(input);
298 const auto &def = operand.getDef();
299 if (!def.valid())
300 continue;
301 dfs(def, operations().at(def));
302 }
303
304 ret.push_back(index);
305 };
306
307 dfs(loss_idx, operations().at(loss_idx));
308 std::reverse(ret.begin(), ret.end());
309 validateBackwardTopologicalOrder(ret);
310
311 return ret;
312}
OperationIndex getDef() const
Definition Operand.h:53
const OperandIndexSequence & getInputs() const override
const Operands & operands() const override
const Object & at(const Index &index) const
Get the object that is associated with the given index.
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
virtual OpCode opcode() const =0

References onert::util::Set< Element >::add(), onert::util::ObjectManager< Index, Object >::at(), onert::util::Set< Element >::contains(), onert::ir::DUPLICATED, onert::ir::Operand::getDef(), onert::ir::IOperation::getInputs(), onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::IOperation::opcode(), operands(), operations(), onert::util::Set< Element >::remove(), and onert::ir::UNDEFINED.

Referenced by essentialBackwardOrder().

◆ changeBackwardShape()

void onert::ir::train::TrainableGraph::changeBackwardShape ( const OperandIndex & ind,
const ir::Shape & new_shape 
)

Definition at line 138 of file TrainableGraph.cc.

139{
140 assert(_backward_operands.exist(index));
141 _backward_operands.at(index).info().shape(new_shape);
142}
const OperandInfo & info(void) const
Definition Operand.h:48
const Shape & shape() const
Return tensor shape.
Definition OperandInfo.h:95
bool exist(const Index &index) const
Get the object that is associated with the given index.

References onert::util::ObjectManager< Index, Object >::at(), onert::util::ObjectManager< Index, Object >::exist(), onert::ir::Operand::info(), and onert::ir::OperandInfo::shape().

◆ changeShape()

void onert::ir::train::TrainableGraph::changeShape ( const OperandIndex & ind,
const ir::Shape & new_shape 
)
override virtual

Implements onert::ir::IGraph.

Definition at line 133 of file TrainableGraph.cc.

134{
135 _graph.changeShape(index, new_shape);
136}
void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override
Definition Graph.cc:117

References onert::ir::Graph::changeShape().

◆ disableBackward()

void onert::ir::train::TrainableGraph::disableBackward ( const OperationIndex & index)

Definition at line 188 of file TrainableGraph.cc.

189{
190 auto &op = dynamic_cast<ir::train::ITrainableOperation &>(_graph.operations().at(index));
191 op.disableBackward();
192}
const Operations & operations() const override
Definition Graph.h:114

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::train::ITrainableOperation::disableBackward(), and onert::ir::Graph::operations().

◆ enableBackward()

void onert::ir::train::TrainableGraph::enableBackward ( const OperationIndex & index)

Definition at line 181 of file TrainableGraph.cc.

182{
183 auto op = dynamic_cast<ir::train::ITrainableOperation *>(&_graph.operations().at(index));
184 assert(op);
185 op->enableBackward();
186}

References onert::util::ObjectManager< Index, Object >::at(), and onert::ir::Graph::operations().

◆ essentialBackwardOrder()

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::essentialBackwardOrder ( ) const

Definition at line 314 of file TrainableGraph.cc.

315{
316 auto backward_order = btopolSortOperations();
317 // get rid of all nodes not reachable from a node with trainable parameters
318 backward_order = truncateBackwardOrder(backward_order, [&](const OperationIndex &index) {
319 return operation(index).isRequiredForBackward();
320 });
321
322 return truncateBackwardOrder(backward_order);
323}
virtual bool isRequiredForBackward() const =0
const ITrainableOperation & operation(OperationIndex index) const
std::vector< ir::OperationIndex > btopolSortOperations() const
std::vector< ir::OperationIndex > truncateBackwardOrder(const std::vector< ir::OperationIndex > &backward_order) const
Truncate the backward order of operations in accordance with the alive condition whether the correspo...

References btopolSortOperations(), onert::ir::train::ITrainableOperation::isRequiredForBackward(), operation(), and truncateBackwardOrder().

Referenced by onert::backend::train::TensorPlanner::planBackPropTensors(), onert::backend::train::TensorPlanner::planDisposableBackPropTensors(), onert::backend::train::TensorPlanner::planGradientTensors(), onert::backend::train::TensorPlanner::planLayerScopeTensors(), and onert::backend::train::TensorPlanner::planNonConstTensors().

◆ getInputIndex()

IOIndex onert::ir::train::TrainableGraph::getInputIndex ( const std::string &  name) const
override virtual

Implements onert::ir::IGraph.

Definition at line 123 of file TrainableGraph.cc.

124{
125 return _graph.getInputIndex(name);
126}
IOIndex getInputIndex(const std::string &name) const override
Definition Graph.cc:137

References onert::ir::Graph::getInputIndex().

◆ getInputs()

const OperandIndexSequence & onert::ir::train::TrainableGraph::getInputs ( ) const
inline override virtual

Implements onert::ir::IGraph.

Definition at line 119 of file TrainableGraph.h.

119{ return _graph.getInputs(); }
const OperandIndexSequence & getInputs() const override
Definition Graph.h:106

References onert::ir::Graph::getInputs().

Referenced by onert::exec::train::TrainableExecutor::TrainableExecutor().

◆ getLossIndex()

OperandIndex onert::ir::train::TrainableGraph::getLossIndex ( const IOIndex & pred_io_ind) const

Definition at line 374 of file TrainableGraph.cc.

375{
376 auto itr = _losses.find(pred_ioind);
377 return (itr == _losses.end()) ? OperandIndex{} : itr->second;
378}

Referenced by onert::exec::train::TrainableExecutor::getLoss().

◆ getOutputIndex()

IOIndex onert::ir::train::TrainableGraph::getOutputIndex ( const std::string &  name) const
override virtual

Implements onert::ir::IGraph.

Definition at line 128 of file TrainableGraph.cc.

129{
130 return _graph.getOutputIndex(name);
131}
IOIndex getOutputIndex(const std::string &name) const override
Definition Graph.cc:143

References onert::ir::Graph::getOutputIndex().

◆ getOutputs()

const OperandIndexSequence & onert::ir::train::TrainableGraph::getOutputs ( ) const
inline override virtual

Implements onert::ir::IGraph.

Definition at line 120 of file TrainableGraph.h.

120{ return _graph.getOutputs(); }
const OperandIndexSequence & getOutputs() const override
Definition Graph.h:108

References onert::ir::Graph::getOutputs().

Referenced by onert::compiler::train::pass::LossInsertionPass::run(), and onert::exec::train::TrainableExecutor::TrainableExecutor().

◆ graph()

◆ operands() [1/2]

Operands & onert::ir::train::TrainableGraph::operands ( )
inline

Definition at line 124 of file TrainableGraph.h.

124{ return _graph.operands(); } // TODO Remove this non-const accessor
const Operands & operands() const override
Definition Graph.h:112

References onert::ir::Graph::operands().

◆ operands() [2/2]

◆ operation()

◆ operations()

◆ removeOperand()

void onert::ir::train::TrainableGraph::removeOperand ( const OperandIndex & ind)

Definition at line 173 of file TrainableGraph.cc.

173{ _graph.removeOperand(ind); }
void removeOperand(const OperandIndex &ind)
Definition Graph.h:94

References onert::ir::Graph::removeOperand().

◆ replaceOperation()

OperationIndex onert::ir::train::TrainableGraph::replaceOperation ( OperationIndex  index,
std::unique_ptr< ITrainableOperation > &&  operation 
)

Replace a trainable operation which the graph already has.

If the given index is available, the call succeeds and operation is moved, which invalidates the caller's pointer. If the given operation has at least one invalid operand index, the call fails and operation is not moved, so the caller's pointer remains valid.

No information in the graph is changed except for replacing an operation.

Parameters
operation Operation to be added
Returns
OperationIndex index if successful, UNDEFINED otherwise

Definition at line 111 of file TrainableGraph.cc.

113{
114 return _graph.replaceOperation(index, std::move(operation));
115}
OperationIndex replaceOperation(OperationIndex index, std::unique_ptr< IOperation > &&operation)
Replace an operation which the graph already has.
Definition Graph.cc:94

References operation(), and onert::ir::Graph::replaceOperation().

Referenced by TrainableGraph().

◆ setInputs()

void onert::ir::train::TrainableGraph::setInputs ( OperandIndexSequence  inputs,
std::unordered_map< std::string, IOIndex >  name_to_input 
)

◆ setOutputs()

void onert::ir::train::TrainableGraph::setOutputs ( OperandIndexSequence  outputs,
std::unordered_map< std::string, IOIndex >  name_to_output 
)

◆ setTrainingUseDefs()

void onert::ir::train::TrainableGraph::setTrainingUseDefs ( const UseDefChains & training_defuses)

Definition at line 194 of file TrainableGraph.cc.

195{
196 _training_defuses.clear();
197 // TODO Replace this loop with `std::unordered_map::insert_range` since C++23
198 for (const auto &[training_index, usedef] : training_defuses)
199 {
200 _training_defuses.emplace(training_index, usedef);
201 }
202}

Referenced by updateGraphDependency().

◆ topolSortOperations()

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::topolSortOperations ( ) const

Definition at line 267 of file TrainableGraph.cc.

268{
269 auto ret = _graph.topolSortOperations();
270 validateForwardTopologicalOrder(ret);
271
272 return ret;
273}
std::vector< ir::OperationIndex > topolSortOperations() const
Definition Graph.cc:184

References onert::ir::Graph::topolSortOperations().

Referenced by onert::backend::train::TensorPlanner::planLayerScopeTensors(), onert::backend::train::TensorPlanner::planNonConstTensors(), and onert::ir::train::UseDefGenerator::UseDefGenerator().

◆ trainingUseDefs()

◆ truncateBackwardOrder() [1/2]

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::truncateBackwardOrder ( const std::vector< ir::OperationIndex > &  backward_order) const

Truncate the backward order of operations, using as the alive condition whether the corresponding operation has trainable parameters.

Parameters
backward_order The order of operations in a backward graph

Definition at line 360 of file TrainableGraph.cc.

361{
362 return truncateBackwardOrder(backward_order, [&](const ir::OperationIndex &index) {
363 const auto &trainable_op = operation(index);
364
365 return trainable_op.hasTrainableParameter();
366 });
367}

References operation(), and truncateBackwardOrder().

Referenced by essentialBackwardOrder(), and truncateBackwardOrder().

◆ truncateBackwardOrder() [2/2]

std::vector< ir::OperationIndex > onert::ir::train::TrainableGraph::truncateBackwardOrder ( std::vector< ir::OperationIndex >  backward_order,
std::function< bool(const ir::OperationIndex &)>  truncating_cond 
) const

Truncate the backward order of operations in accordance with the given alive condition.

Parameters
backward_order The order of operations in a backward graph
alive_cond The alive condition to stop the backward order (declared as truncating_cond in the signature above)

Definition at line 325 of file TrainableGraph.cc.

328{
329 auto forward_order = backward_order;
330 std::reverse(forward_order.begin(), forward_order.end());
331 std::set<ir::OperationIndex> alive;
332
333 for (const auto &index : forward_order)
334 {
335 if (alive_cond(index))
336 alive.insert(index);
337
338 // TODO: replace this with `std::set::contains` after C++20
339 if (alive.find(index) != alive.end())
340 {
341 const auto &op = operations().at(index);
342 for (const auto &output : op.getOutputs())
343 {
344 const auto &operand = operands().at(output);
345 for (const auto &use : operand.getUses())
346 alive.insert(use);
347 }
348 }
349 }
350
351 // TODO: replace this with `std::erase_if(std::vector)` after C++20
352 backward_order.erase(
353 std::remove_if(backward_order.begin(), backward_order.end(),
354 [&](const auto &index) { return alive.find(index) == alive.end(); }),
355 backward_order.end());
356 return backward_order;
357}
const OperandIndexSequence & getOutputs() const override

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::IOperation::getOutputs(), operands(), and operations().

◆ updateGraphDependency()

void onert::ir::train::TrainableGraph::updateGraphDependency ( )

Definition at line 380 of file TrainableGraph.cc.

381{
382 _graph.verify();
383
384 // Initialize training usedefs
386
387 disableUnusedBackwardNodes(_training_defuses, *this);
388
389 verifyTrainingUseDefs();
390}
void verify(void) const
Definition Graph.cc:149
void setTrainingUseDefs(const UseDefChains &training_defuses)

References setTrainingUseDefs(), and onert::ir::Graph::verify().

◆ verify()

void onert::ir::train::TrainableGraph::verify ( void  ) const

Definition at line 154 of file TrainableGraph.cc.

155{
156 _graph.verify();
157
159 try
160 {
161 [[maybe_unused]] const auto &casted_op =
162 dynamic_cast<const onert::ir::train::ITrainableOperation &>(op);
163 }
164 catch (const std::bad_cast &)
165 {
166 throw std::runtime_error("TrainableGraph: " + op.name() + " is not a trainable operation");
167 }
168 });
169
170 verifyTrainingUseDefs();
171}
virtual std::string name() const
Definition IOperation.h:38

References onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::IOperation::name(), operations(), and onert::ir::Graph::verify().


The documentation for this class was generated from the following files: