ONE - On-device Neural Engine
nnc::SinkRelu Class Reference

This pass sinks ReLU below MaxPooling and Concat nodes.

#include <SinkRelu.h>


Public Member Functions

PassData run (PassData data) override
 Run the compiler pass.

std::string getName () override

- Public Member Functions inherited from nnc::Pass
virtual void cleanup ()
 Clean up compiler pass data.

virtual ~Pass ()=default

Detailed Description

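This pass sinks ReLU below MaxPooling and Concat nodes. The rewrite preserves the result for two reasons. ReLU, max(0, x), is monotonically nondecreasing, so taking the maximum over a pooling window gives the same value whether ReLU is applied before or after it. ReLU is also applied element-wise, so it commutes with concatenation. After the rewrite, a Concat whose inputs are all ReLU outputs becomes a Concat of the ReLU inputs followed by a single ReLU, and the ReLU after MaxPool runs on the pooled tensor, which is usually smaller than the pooling input.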

Definition at line 29 of file SinkRelu.h.
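As a sanity check of the two identities the pass relies on, the following toy sketch applies ReLU before and after a 1-D max pooling and a concatenation and compares the results. The helpers `relu`, `maxPool1D`, `concat`, and `sinkingPreservesResult` are hypothetical and work on plain `std::vector<float>` values rather than on nnc/mir tensors. The pooling uses non-overlapping windows with no padding, which is a simplifying assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Element-wise ReLU: max(0, x).
std::vector<float> relu(const std::vector<float> &x)
{
  std::vector<float> y(x.size());
  std::transform(x.begin(), x.end(), y.begin(), [](float v) { return std::max(0.0f, v); });
  return y;
}

// 1-D max pooling with non-overlapping windows of size k (stride == k, no padding).
// Trailing elements that do not fill a whole window are dropped.
std::vector<float> maxPool1D(const std::vector<float> &x, std::size_t k)
{
  std::vector<float> y;
  for (std::size_t i = 0; i + k <= x.size(); i += k)
  {
    auto first = x.begin() + static_cast<std::ptrdiff_t>(i);
    y.push_back(*std::max_element(first, first + static_cast<std::ptrdiff_t>(k)));
  }
  return y;
}

// Concatenation of two 1-D tensors along their only axis.
std::vector<float> concat(const std::vector<float> &a, const std::vector<float> &b)
{
  std::vector<float> y(a);
  y.insert(y.end(), b.begin(), b.end());
  return y;
}

// True if sinking ReLU below MaxPool and below Concat leaves the result unchanged for these inputs.
// Exact comparison is safe here: max and ReLU only select existing values, they never round.
bool sinkingPreservesResult(const std::vector<float> &a, const std::vector<float> &b, std::size_t k)
{
  bool pool_ok = maxPool1D(relu(a), k) == relu(maxPool1D(a, k));
  bool concat_ok = concat(relu(a), relu(b)) == relu(concat(a, b));
  return pool_ok && concat_ok;
}
```

For example, with a = {-1, 2, -3, -0.5} and a window of 2, both orders produce {2, 0}.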

Member Function Documentation

◆ getName()

std::string nnc::SinkRelu::getName ()
inline override virtual

Reimplemented from nnc::Pass.

Definition at line 34 of file SinkRelu.h.

{ return "SinkRelu"; };

◆ run()

PassData nnc::SinkRelu::run (PassData data)
override virtual

Run the compiler pass.

Parameters
data - the data the pass takes as input (here, the graph to transform)
Returns
data that can be passed to the next pass
Exceptions
PassException - thrown if errors occur

Implements nnc::Pass.
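The `run` contract above (take `PassData`, return data for the next pass) can be illustrated with a minimal sketch. `ToyGraph`, `ToyPassData`, `ToyPass`, `ToySinkRelu`, and `runPipeline` are hypothetical stand-ins, not the nnc classes, and the real `PassData` and pass driver may differ. The sketch only mirrors the members listed on this page and the conversions that `SinkRelu::run` relies on: `static_cast<Graph *>(data)` on entry and `return g;` on exit.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for the graph a pass transforms.
struct ToyGraph
{
  std::vector<std::string> log; // names of the passes that have run on this graph
};

// Hypothetical stand-in for nnc::PassData: converts implicitly from a graph pointer
// and explicitly back to one, matching how SinkRelu::run uses PassData.
class ToyPassData
{
public:
  ToyPassData(ToyGraph *g) : _graph(g) {}
  explicit operator ToyGraph *() const { return _graph; }

private:
  ToyGraph *_graph;
};

// Mirrors the members this page lists for nnc::Pass.
class ToyPass
{
public:
  virtual ToyPassData run(ToyPassData data) = 0;
  virtual std::string getName() { return "unnamed"; }
  virtual void cleanup() {}
  virtual ~ToyPass() = default;
};

class ToySinkRelu : public ToyPass
{
public:
  ToyPassData run(ToyPassData data) override
  {
    auto *g = static_cast<ToyGraph *>(data);
    assert(g);
    g->log.push_back(getName()); // a real pass would rewrite the graph here
    return g;                    // converted back to ToyPassData for the next pass
  }
  std::string getName() override { return "SinkRelu"; }
};

// Hypothetical driver: feeds each pass the output of the previous one, then cleans up.
ToyPassData runPipeline(const std::vector<std::unique_ptr<ToyPass>> &passes, ToyPassData data)
{
  for (const auto &p : passes)
    data = p->run(data);
  for (const auto &p : passes)
    p->cleanup();
  return data;
}
```

Returning the (possibly rewritten) graph from `run`, rather than mutating shared state, is what lets a driver chain passes by simply threading the returned value into the next call.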

Definition at line 33 of file SinkRelu.cpp.

{
  auto g = static_cast<Graph *>(data);
  assert(g);
  GraphPatternMatcher matcher(g);
  auto is_relu = [](const Operation *op) { return op->getType() == Operation::Type::ReLU; };
  auto is_concat = [](const Operation *op) { return op->getType() == Operation::Type::concat; };
  auto is_max_pool = [](const Operation *op) {
    return op->getType() == Operation::Type::maxPool2D;
  };
  std::vector<std::pair<Operation *, Operation *>> matches;

  // sink ReLU through MaxPool
  matches = matcher.matchEdge(is_relu, is_max_pool);
  for (auto pair : matches)
  {
    swapAdjacent(g, pair.first, pair.second);
  }
  // sink ReLU through Concat
  auto matches_v = matcher.matchUpBush(is_relu, is_concat);
  for (const auto &pair : matches_v)
  {
    auto relus = pair.first;
    auto *concat = dynamic_cast<ops::ConcatOp *>(pair.second);
    // collect the tensors that fed each ReLU
    std::vector<Operation::Output *> pre_relu;
    pre_relu.reserve(relus.size());
    for (auto *r : relus)
    {
      pre_relu.emplace_back(r->getInput(0));
    }
    // create replacement nodes
    auto new_concat = g->create<ops::ConcatOp>(pre_relu, concat->getAxis());
    auto new_relu = g->create<ops::ReluOp>(new_concat->getOutput(0));

    // concat is deleted here
    g->replaceNode(concat, new_relu);
    // drop the original ReLUs if nothing else consumes them
    for (auto r : relus)
    {
      removeNodeIfUnused(g, r);
    }
  }
  return g;
}
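To make the two rewrites concrete outside of the mir API, here is a self-contained sketch on a hypothetical toy IR. `Node`, `Graph`, `replaceUses`, `sinkReluThroughMaxPool`, and `sinkReluThroughConcat` are illustrative names, not ONE code. The matching rules (a MaxPool fed directly by a ReLU, and a Concat whose inputs are all ReLUs) are an assumed reading of how `matchEdge` and `matchUpBush` are used above. Like the real pass, the sketch collects matches before rewriting, builds replacement nodes, and redirects users of the old node to the new ReLU. Unlike the real pass, it ignores pooling parameters and the Concat axis, and it leaves dead nodes in place instead of removing them.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR node: an operation kind plus the nodes that produce its operands.
struct Node
{
  std::string kind;           // "input", "relu", "maxpool", "concat", "output", ...
  std::vector<Node *> inputs; // producers of this node's operands, in order
};

// Hypothetical toy graph that owns its nodes.
struct Graph
{
  std::vector<std::unique_ptr<Node>> nodes;

  Node *create(std::string kind, std::vector<Node *> inputs)
  {
    nodes.push_back(std::make_unique<Node>(Node{std::move(kind), std::move(inputs)}));
    return nodes.back().get();
  }

  // Redirects every use of `from` to `to`; `from` stays in the graph with no users.
  void replaceUses(Node *from, Node *to)
  {
    for (auto &n : nodes)
      for (auto &in : n->inputs)
        if (in == from)
          in = to;
  }
};

// maxpool(relu(x)) -> relu(maxpool(x)) for every MaxPool fed directly by a ReLU.
int sinkReluThroughMaxPool(Graph &g)
{
  std::vector<Node *> matches; // collected first because create() grows g.nodes
  for (auto &n : g.nodes)
    if (n->kind == "maxpool" && n->inputs.size() == 1 && n->inputs[0]->kind == "relu")
      matches.push_back(n.get());
  for (Node *pool : matches)
  {
    Node *relu = pool->inputs[0];
    Node *new_pool = g.create("maxpool", {relu->inputs.at(0)});
    Node *new_relu = g.create("relu", {new_pool});
    g.replaceUses(pool, new_relu);
  }
  return static_cast<int>(matches.size());
}

// concat(relu(a), relu(b), ...) -> relu(concat(a, b, ...)) when every Concat input is a ReLU.
int sinkReluThroughConcat(Graph &g)
{
  std::vector<Node *> matches;
  for (auto &n : g.nodes)
  {
    if (n->kind != "concat" || n->inputs.empty())
      continue;
    bool all_relu = true;
    for (Node *in : n->inputs)
      all_relu = all_relu && in->kind == "relu";
    if (all_relu)
      matches.push_back(n.get());
  }
  for (Node *concat : matches)
  {
    std::vector<Node *> pre_relu; // the tensors that fed each ReLU
    for (Node *r : concat->inputs)
      pre_relu.push_back(r->inputs.at(0));
    Node *new_concat = g.create("concat", pre_relu);
    Node *new_relu = g.create("relu", {new_concat});
    g.replaceUses(concat, new_relu);
  }
  return static_cast<int>(matches.size());
}
```

For instance, building input → relu → maxpool → output and calling `sinkReluThroughMaxPool` leaves the output fed by relu(maxpool(input)), while the original ReLU and MaxPool remain in the toy graph without users.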
void swapAdjacent(mir::Graph *g, mir::Operation *top, mir::Operation *bottom)
Swaps adjacent nodes in the graph: creates new nodes and replaces the old ones with them.
void removeNodeIfUnused(mir::Graph *g, mir::Operation *op)

References mir::GraphPatternMatcher::matchEdge(), mir::GraphPatternMatcher::matchUpBush(), nnc::opt_util::removeNodeIfUnused(), and nnc::opt_util::swapAdjacent().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: