ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnc::DataFormatSwitcher Class Reference

#include <DataFormatSwitcher.h>

Collaboration diagram for nnc::DataFormatSwitcher:

Public Member Functions

 DataFormatSwitcher (mir::DataFormat target_format)
 
PassData run (PassData data) override
 run compiler pass
 
void cleanup () override
 clean compiler pass data
 
 ~DataFormatSwitcher () override
 
std::string getName () override
 
- Public Member Functions inherited from nnc::Pass
virtual ~Pass ()=default
 

Detailed Description

Definition at line 29 of file DataFormatSwitcher.h.

Constructor & Destructor Documentation

◆ DataFormatSwitcher()

nnc::DataFormatSwitcher::DataFormatSwitcher ( mir::DataFormat  target_format)
explicit

Definition at line 29 of file DataFormatSwitcher.cpp.

30 : _target_format(target_format)
31{
32}

◆ ~DataFormatSwitcher()

nnc::DataFormatSwitcher::~DataFormatSwitcher ( )
override default

Member Function Documentation

◆ cleanup()

void nnc::DataFormatSwitcher::cleanup ( )
override virtual

clean compiler pass data

Reimplemented from nnc::Pass.

Definition at line 85 of file DataFormatSwitcher.cpp.

85{ _candidates_for_switch.clear(); }

◆ getName()

std::string nnc::DataFormatSwitcher::getName ( )
inline override virtual

Reimplemented from nnc::Pass.

Definition at line 40 of file DataFormatSwitcher.h.

40{ return "DataFormatSwitcher"; }

◆ run()

PassData nnc::DataFormatSwitcher::run ( PassData  data)
override virtual

run compiler pass

Parameters
data - the data that the pass takes as input
Returns
data that can be passed to the next pass
Exceptions
PassException object if errors occurred

Implements nnc::Pass.

Definition at line 36 of file DataFormatSwitcher.cpp.

37{
38 _graph = static_cast<mir::Graph *>(data);
39 assert(_graph);
40
41 // Collect nodes which use DataFormat
42 for (auto *node : _graph->getNodes())
43 {
44 switch (node->getType())
45 { // nodes using DataFormat
46 case mir::Operation::Type::avgPool2D:
47 case mir::Operation::Type::conv2D:
48 case mir::Operation::Type::deConv2D:
49 case mir::Operation::Type::depthwiseConv:
50 case mir::Operation::Type::maxPool2D:
51 _candidates_for_switch.push_back(node);
52 break;
53 default:
54 break; // not use DataFormat
55 }
56 }
57 // Switch collected ops
58 for (auto *op : _candidates_for_switch)
59 {
60 switch (op->getType())
61 {
62 case mir::Operation::Type::avgPool2D:
63 switchAvgPool2D(dynamic_cast<mir::ops::AvgPool2DOp *>(op));
64 break;
65 case mir::Operation::Type::conv2D:
66 switchConv2D(dynamic_cast<mir::ops::Conv2DOp *>(op));
67 break;
68 case mir::Operation::Type::deConv2D:
69 switchDeConv2D(dynamic_cast<mir::ops::DeConv2DOp *>(op));
70 break;
71 case mir::Operation::Type::depthwiseConv:
72 switchDepthwiseConv2D(dynamic_cast<mir::ops::DepthwiseConv2DOp *>(op));
73 break;
74 case mir::Operation::Type::maxPool2D:
75 switchMaxPool2D(dynamic_cast<mir::ops::MaxPool2DOp *>(op));
76 break;
77 default:
78 assert(false && "Can't switch DataFormat for this operation!");
79 }
80 }
81
82 return _graph;
83}

References mir::Graph::getNodes().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: