ONE - On-device Neural Engine
onert::compiler::train::pass::TrainableConstantInsertionPass Class Reference

#include <TrainableConstantInsertionPass.h>


Public Member Functions

std::string id () final
 Returns the string ID of this pass, which is the same as the class name.
 
void callback (const ir::OperationIndex &index, ir::IOperation &node) final
 Called for every node of the graph.
 
- Public Member Functions inherited from onert::compiler::pass::LoweredOperationPass
 LoweredOperationPass (ILoweredGraph &lowered_graph)
 
virtual ~LoweredOperationPass ()=default
 
- Public Member Functions inherited from onert::compiler::pass::OperationPass
virtual ~OperationPass ()=default
 
void run () final
 Run the pass.
 
 Pass (ir::Graph &graph)
 
- Public Member Functions inherited from onert::compiler::pass::Pass
 Pass (ir::Graph &graph)
 
virtual ~Pass ()=default
 
- Public Member Functions inherited from onert::compiler::pass::IPass
virtual ~IPass ()=default
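The inherited members above suggest a template-method layering: `OperationPass::run()` drives the traversal and invokes `callback()` for each node, so a concrete pass such as this one supplies only `id()` and `callback()`. The following is a minimal C++ sketch of that layering under stated assumptions: `Graph`, `Operation`, and `RecordingPass` are hypothetical simplified stand-ins (not onert's `ir::Graph` or `ir::IOperation`), `LoweredOperationPass` is omitted, and the index-based loop in `run()` is an assumption about how the real `OperationPass` visits operations.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for onert's ir::IOperation and ir::Graph.
struct Operation
{
  std::string name;
};

struct Graph
{
  std::vector<Operation> operations;
};

// Root interface: every pass has a string id and can be run.
class IPass
{
public:
  virtual ~IPass() = default;
  virtual std::string id() = 0;
  virtual void run() = 0;
};

// Holds a reference to the graph being transformed (cf. Pass::_graph).
class Pass : public IPass
{
public:
  explicit Pass(Graph &graph) : _graph{graph} {}

protected:
  Graph &_graph;
};

// Template method (assumed): run() visits every operation and calls callback().
class OperationPass : public Pass
{
public:
  explicit OperationPass(Graph &graph) : Pass{graph} {}

  void run() final
  {
    for (std::size_t i = 0; i < _graph.operations.size(); ++i)
      callback(i, _graph.operations[i]);
  }

  virtual void callback(std::size_t index, Operation &node) = 0;
};

// Hypothetical concrete pass: supplies only id() and callback(), no loop of its own.
class RecordingPass final : public OperationPass
{
public:
  explicit RecordingPass(Graph &graph) : OperationPass{graph} {}

  std::string id() final { return "RecordingPass"; }

  void callback(std::size_t index, Operation &node) final
  {
    visited.push_back(std::to_string(index) + ":" + node.name);
  }

  std::vector<std::string> visited; // "index:name" for each node, in visit order
};
```

Marking `run()` as `final` keeps the traversal in one place, so derived passes cannot accidentally replace it with a different iteration order.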
 

Additional Inherited Members

- Protected Attributes inherited from onert::compiler::pass::LoweredOperationPass
ILoweredGraph & _lowered_graph
 
- Protected Attributes inherited from onert::compiler::pass::Pass
ir::Graph & _graph
 

Detailed Description

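For each constant input operand of a node that is also used by other nodes, this pass inserts a new copy of the operand for each of those other nodes, so that the original operand and every new copy each have exactly one use.

Definition at line 27 of file TrainableConstantInsertionPass.h.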

Member Function Documentation

◆ callback()

void onert::compiler::train::pass::TrainableConstantInsertionPass::callback (const ir::OperationIndex &index, ir::IOperation &node)
final, virtual

Called for every node of the graph.

Parameters
index  the index of the node in the graph
node  the node in the graph

Implements onert::compiler::pass::LoweredOperationPass.

Definition at line 25 of file TrainableConstantInsertionPass.cc.

// Function body; the definition names its parameters node_index and node.
{
  for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
  {
    auto &object = _graph.operands().at(input);

    // Skip operands that are not constant, and constants used by only one node
    if (!object.isConstant() || object.getUses().size() < 2)
      continue;

    // Give every other user of the shared constant its own new operand;
    // the current node keeps the original. The use set is copied because
    // updateUseDef() modifies it while we iterate.
    const auto uses(object.getUses());
    for (const auto &use_index : uses)
    {
      if (use_index == node_index)
        continue;

      // NOTE The PermuteFactor (backend and layout) of the current node and the use node may
      //      differ. This is not a problem because each of these constant operands will have
      //      only one use node.
      const auto new_index = insertNewOperand(object);
      updateUseDef(input, new_index, use_index);
    }

    // The original operand is now used only by the current node
    assert(object.getUses().size() == 1 && object.getUses().contains(node_index));
  }
}

References onert::compiler::pass::Pass::_graph, onert::util::ObjectManager< Index, Object >::at(), onert::ir::DUPLICATED, onert::ir::IOperation::getInputs(), onert::ir::Graph::operands(), size, and onert::ir::UNDEFINED.
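The effect of the callback body can be reproduced on a toy graph: every operation other than the current one that reads a shared constant is switched to a private copy. This is a minimal C++ sketch, assuming a hypothetical `ToyGraph` with integer indices and `std::set` use lists. `insertNewOperand` and `updateUseDef` here are hypothetical stand-ins with the behavior the listing implies, not onert's actual helpers, and the `seen` set plus the `kUndefined` check only imitate `ir::Remove::DUPLICATED | ir::Remove::UNDEFINED`.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// Hypothetical toy IR; this is not onert's ir::Operand or ir::Graph API.
constexpr int kUndefined = -1; // stands in for an undefined (omitted optional) input

struct Operand
{
  bool constant = false;
  std::vector<float> data;
  std::set<int> uses; // indices of the operations that read this operand
};

struct Op
{
  std::vector<int> inputs; // operand indices; may repeat or contain kUndefined
};

struct ToyGraph
{
  std::vector<Operand> operands;
  std::vector<Op> ops;
};

// Hypothetical analogue of insertNewOperand(): appends a copy with no uses.
// The source is taken by value so it is copied before push_back can reallocate.
int insertNewOperand(ToyGraph &g, Operand copy)
{
  copy.uses.clear();
  g.operands.push_back(std::move(copy));
  return static_cast<int>(g.operands.size()) - 1;
}

// Hypothetical analogue of updateUseDef(): makes operation `use` read
// `new_index` instead of `old_index` and moves the use entry accordingly.
void updateUseDef(ToyGraph &g, int old_index, int new_index, int use)
{
  for (int &in : g.ops[use].inputs)
    if (in == old_index)
      in = new_index;
  g.operands[old_index].uses.erase(use);
  g.operands[new_index].uses.insert(use);
}

// Mirrors the callback body for the operation at node_index.
void splitSharedConstants(ToyGraph &g, int node_index)
{
  std::set<int> seen; // imitates Remove::DUPLICATED together with the kUndefined check
  for (int input : g.ops[node_index].inputs)
  {
    if (input == kUndefined || !seen.insert(input).second)
      continue;

    // Skip operands that are not constant, and constants used by only one operation
    if (!g.operands[input].constant || g.operands[input].uses.size() < 2)
      continue;

    // Iterate over a copy: updateUseDef() erases entries from the live use set
    const std::set<int> uses = g.operands[input].uses;
    for (int use_index : uses)
    {
      if (use_index == node_index)
        continue;
      const int new_index = insertNewOperand(g, g.operands[input]);
      updateUseDef(g, input, new_index, use_index);
    }

    // The original operand is now used only by the current operation
    assert(g.operands[input].uses.size() == 1 && g.operands[input].uses.count(node_index) == 1);
  }
}
```

For example, if operations 0, 1, and 2 all read constant operand 0 and the pass visits operation 0, operations 1 and 2 end up reading new operands 2 and 3, while operand 0 keeps a single use.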

◆ id()

std::string onert::compiler::train::pass::TrainableConstantInsertionPass::id ()
inline, final, virtual

Returns the string ID of this pass, which is the same as the class name.

Returns
the string ID

Implements onert::compiler::pass::LoweredOperationPass.

Definition at line 33 of file TrainableConstantInsertionPass.h.

{ return "TrainableConstantInsertionPass"; }

The documentation for this class was generated from the following files: