ONE - On-device Neural Engine
onert::compiler::train::pass::TrainableConstantInsertionPass Class Reference

#include <TrainableConstantInsertionPass.h>


Public Member Functions

std::string id () final
 Returns the string id of this pass, which is the same as the class name.
 
void callback (const ir::OperationIndex &index, ir::IOperation &node) final
 Called for every node of the graph.
 
- Public Member Functions inherited from onert::compiler::pass::LoweredOperationPass
 LoweredOperationPass (ILoweredGraph &lowered_graph)
 
virtual ~LoweredOperationPass ()=default
 
- Public Member Functions inherited from onert::compiler::pass::OperationPass
virtual ~OperationPass ()=default
 
void run () final
 Run the pass.
 
 Pass (ir::Graph &graph)
 
- Public Member Functions inherited from onert::compiler::pass::Pass
 Pass (ir::Graph &graph)
 
virtual ~Pass ()=default
 
- Public Member Functions inherited from onert::compiler::pass::IPass
virtual ~IPass ()=default
 

Additional Inherited Members

- Protected Attributes inherited from onert::compiler::pass::LoweredOperationPass
ILoweredGraph & _lowered_graph
 
- Protected Attributes inherited from onert::compiler::pass::Pass
ir::Graph & _graph
 

Detailed Description

Definition at line 33 of file TrainableConstantInsertionPass.h.

Member Function Documentation

◆ callback()

void onert::compiler::train::pass::TrainableConstantInsertionPass::callback ( const ir::OperationIndex & index,
ir::IOperation & node
)
final virtual

Called for every node of the graph.

Parameters
index  the index of a node in the graph
node   the node in the graph

Implements onert::compiler::pass::LoweredOperationPass.

Definition at line 31 of file TrainableConstantInsertionPass.cc.

{
  for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
  {
    auto &object = _graph.operands().at(input);

    // Skip if the operand is not constant or not shared constant
    if (!object.isConstant() || object.getUses().size() < 2)
      continue;

    // Insert new operands for shared constant except for the current node.
    const auto uses(object.getUses());
    for (const auto &use_index : uses)
    {
      if (use_index == node_index)
        continue;

      // NOTE The PermuteFactor(backend and layout) of the current node and the use node may be
      //      different. But there is no problem because both nodes' constant operand will have
      //      only one use node.
      const auto new_index = insertNewOperand(object);
      updateUseDef(input, new_index, use_index);
    }

    // The input of the current node will have one use as the current node
    assert(object.getUses().size() == 1 && object.getUses().contains(node_index));
  }
}

References onert::compiler::pass::Pass::_graph, onert::util::ObjectManager< Index, Object >::at(), onert::ir::DUPLICATED, onert::ir::IOperation::getInputs(), onert::ir::Graph::operands(), size, and onert::ir::UNDEFINED.
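
The effect of callback() can be illustrated outside of onert with a minimal, self-contained sketch. Everything in it (ToyOperand, ToyOperation, ToyGraph, unshareConstants) is a hypothetical stand-in invented for illustration and is not part of the onert API; in particular, the inline copy-and-rewire steps are only an assumption about what insertNewOperand() and updateUseDef() accomplish, since their bodies are not shown on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the IR types; these are NOT the real onert classes.
struct ToyOperand
{
  bool is_constant = false;
  std::vector<std::size_t> uses; // operations reading this operand (each listed once)
};

struct ToyOperation
{
  std::vector<std::size_t> inputs; // indices into ToyGraph::operands
};

struct ToyGraph
{
  std::vector<ToyOperand> operands;
  std::vector<ToyOperation> operations;
};

// Gives every other user of a shared constant its own copy, so that the
// original operand ends up being used only by the operation at node_index.
void unshareConstants(ToyGraph &graph, std::size_t node_index)
{
  // Work on a copy of the input list; the graph is mutated inside the loop.
  const std::vector<std::size_t> inputs = graph.operations[node_index].inputs;
  for (std::size_t input : inputs)
  {
    // Skip operands that are not constant or not shared.
    if (!graph.operands[input].is_constant || graph.operands[input].uses.size() < 2)
      continue;

    // Iterate over a snapshot because the use list shrinks as users are rewired.
    const std::vector<std::size_t> uses = graph.operands[input].uses;
    for (std::size_t use_index : uses)
    {
      if (use_index == node_index)
        continue;

      // Assumed role of insertNewOperand(): add a fresh constant operand.
      ToyOperand copy;
      copy.is_constant = true;
      copy.uses = {use_index};
      graph.operands.push_back(copy);
      const std::size_t new_index = graph.operands.size() - 1;

      // Assumed role of updateUseDef(): point the user at the copy and
      // drop it from the use list of the original operand.
      auto &use_inputs = graph.operations[use_index].inputs;
      std::replace(use_inputs.begin(), use_inputs.end(), input, new_index);
      auto &old_uses = graph.operands[input].uses;
      old_uses.erase(std::remove(old_uses.begin(), old_uses.end(), use_index), old_uses.end());
    }

    // The original constant is now used by the current node only.
    assert(graph.operands[input].uses.size() == 1 &&
           graph.operands[input].uses[0] == node_index);
  }
}
```

Calling unshareConstants(graph, 0) on a toy graph in which one constant feeds three operations leaves operation 0 with the original operand and appends two new operands, one for each of the other users. Non-constant operands and constants with a single user are left untouched, mirroring the early continue in the listing above.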

◆ id()

std::string onert::compiler::train::pass::TrainableConstantInsertionPass::id ( )
inline final virtual

Returns the string id of this pass, which is the same as the class name.

Returns
string id

Implements onert::compiler::pass::LoweredOperationPass.

Definition at line 39 of file TrainableConstantInsertionPass.h.

{ return "TrainableConstantInsertionPass"; }

The documentation for this class was generated from the following files: