ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir::GraphPatternMatcher Class Reference

#include <GraphPatternMatcher.h>

Public Types

using Predicate = bool(const Operation *)
 

Public Member Functions

 GraphPatternMatcher (Graph *g)
 
std::vector< std::pair< Operation *, Operation * > > matchEdge (Predicate p1, Predicate p2)
 Match an edge using two predicates, one for each end of the edge.
 
std::vector< std::pair< std::vector< Operation * >, Operation * > > matchUpBush (Predicate p1, Predicate p2)
 Match a two-level tree where the bottommost node has multiple previous nodes.
 

Detailed Description

Definition at line 27 of file GraphPatternMatcher.h.

Member Typedef Documentation

◆ Predicate

Definition at line 30 of file GraphPatternMatcher.h.

Constructor & Destructor Documentation

◆ GraphPatternMatcher()

mir::GraphPatternMatcher::GraphPatternMatcher ( Graph * g)
inline explicit

Definition at line 31 of file GraphPatternMatcher.h.

31: _g(g){};

Member Function Documentation

◆ matchEdge()

std::vector< std::pair< Operation *, Operation * > > mir::GraphPatternMatcher::matchEdge ( GraphPatternMatcher::Predicate  p1,
GraphPatternMatcher::Predicate  p2 
)

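Match an edge using two predicates, one for each end of the edge. For every node satisfying p1, each of its outputs is scanned, and the first user of that output satisfying p2 is recorded (at most one match per output).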

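Parameters
p1 Predicate for the start (producer) node of the edge
p2 Predicate for the end (consumer) node of the edge
Returns
Vector of (start, end) operation pairs for all matches; empty if no matches are found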

Definition at line 26 of file GraphPatternMatcher.cpp.

27{
28
29 std::vector<std::pair<Operation *, Operation *>> matches;
30 for (auto *start : _g->getNodes())
31 {
32 if (p1(start))
33 {
34 for (auto &out : start->getOutputs())
35 {
36 for (auto use : out.getUses())
37 {
38 Operation *end = use.getNode();
39 if (p2(end))
40 {
41 matches.emplace_back(std::make_pair(start, end));
42 break;
43 }
44 }
45 }
46 }
47 }
48 return matches;
49}
ShapeIterator end(const Shape &s)

References mir::Graph::getNodes().

Referenced by nnc::CombineTransposes::run(), nnc::ConstantFoldTranspose::run(), nnc::SinkRelu::run(), and nnc::SinkTranspose::run().

◆ matchUpBush()

std::vector< std::pair< std::vector< Operation * >, Operation * > > mir::GraphPatternMatcher::matchUpBush ( Predicate  p1,
Predicate  p2 
)

Match a two-level tree where the bottommost node has multiple previous nodes. A node satisfying p2 matches only if the producer of every one of its inputs satisfies p1.

Parameters
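p1 Predicate for the top nodes (the producers of the bottom node's inputs)
p2 Predicate for the bottom node
Returns
Vector of (top nodes, bottom node) pairs for all matches, with top nodes listed in input order; empty if no matches are found. Note: a bottom node with no inputs matches trivially, with an empty top-node list.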

Definition at line 52 of file GraphPatternMatcher.cpp.

54{
55 std::vector<std::pair<std::vector<Operation *>, Operation *>> matches;
56 for (auto *root : _g->getNodes())
57 {
58 if (p2(root))
59 {
60 const auto &inputs = root->getInputs();
61 if (std::all_of(inputs.begin(), inputs.end(),
62 [p1](const Operation::Output *input) { return p1(input->getNode()); }))
63 {
64 std::vector<Operation *> tops;
65 tops.reserve(inputs.size());
66 for (Operation::Output *pr : inputs)
67 {
68 tops.emplace_back(pr->getNode());
69 }
70 matches.emplace_back(std::make_pair(tops, root));
71 }
72 }
73 }
74 return matches;
75}
Op * root(Op *)
Return the root Op from a given Op node.
Definition Op.cpp:144

References mir::Graph::getNodes().

Referenced by nnc::SinkRelu::run(), and nnc::SinkTranspose::run().


The documentation for this class was generated from the following files: