ONE - On-device Neural Engine
onert::compiler::pass::OddOutputPass Class Reference

Pass to specially handle odd outputs in a subgraph.

#include <OddOutputPass.h>


Public Member Functions

std::string id () final
void run () override
Pass (ir::Graph &graph)

- Public Member Functions inherited from onert::compiler::pass::Pass
Pass (ir::Graph &graph)
virtual ~Pass ()=default

- Public Member Functions inherited from onert::compiler::pass::IPass
virtual ~IPass ()=default

Additional Inherited Members

- Protected Attributes inherited from onert::compiler::pass::Pass
ir::Graph &_graph

Detailed Description

Pass to specially handle odd outputs in a subgraph.

The runtime Graph IR requires every input and output to have a distinct tensor index; this is a restriction of onert. However, models (and the API) allow the same index to appear more than once, so the graph must be transformed after the model is loaded.

This is necessary because the API lets users set a separate buffer for each input and output, so copying the value at runtime is unavoidable.

Note that this is a mandatory pass for Graph.
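As a rough illustration of the output side of this restriction, the sketch below checks whether an index list is already acceptable: no output index is also an input index, and no output index occurs twice. The function name hasDistinctOutputs is hypothetical and plain ints stand in for ir::OperandIndex; this is not onert's code, only a statement of the condition the pass is meant to establish.

```cpp
#include <algorithm>
#include <cassert>
#include <unordered_set>
#include <vector>

// Hypothetical helper (not part of onert): plain ints stand in for
// ir::OperandIndex. Returns true when no output index is also an input
// index and no output index appears more than once.
bool hasDistinctOutputs(const std::vector<int> &inputs, const std::vector<int> &outputs)
{
  std::unordered_set<int> seen;
  for (int ind : outputs)
  {
    // An output that is also a model input would need a copy.
    if (std::find(inputs.begin(), inputs.end(), ind) != inputs.end())
      return false;
    // insert() reports false in .second when the index was already seen,
    // i.e. the output is duplicated.
    if (!seen.insert(ind).second)
      return false;
  }
  return true;
}
```

For example, inputs {0} with outputs {0} fails the check, as do outputs {1, 1}; after this pass has run, the check is expected to hold.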

Case 1: An operand that is both a model output and a model input

Create a new operand, insert a Permute (copy) operation from the original operand to the new one, and then change the model output to refer to the newly created operand.

e.g.)

((#0 Input0 and also Output0))
becomes
((#0 Input0)) -> [#0 Permute] -> ((#1 Output0))
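A minimal sketch of Case 1 on a toy graph follows. ToyGraph, handleInputAsOutput, and the integer operand indices are hypothetical stand-ins for onert's ir::Graph and ir::OperandIndex; the toy insertPermute only models "create an operand and record a copy into it", and the sketch assumes that replacing an output rewrites every output slot holding the old index.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical toy graph, not onert's ir::Graph: ints stand in for
// ir::OperandIndex, and each Permute op is recorded as a
// (source operand, destination operand) pair.
struct ToyGraph
{
  std::vector<int> inputs;
  std::vector<int> outputs;
  int num_operands = 0;
  std::vector<std::pair<int, int>> permutes;

  // Stand-in for insertPermute(): creates a new operand, records a
  // Permute (copy) from `ind` into it, and returns the new index.
  int insertPermute(int ind)
  {
    assert(ind >= 0 && ind < num_operands);
    const int new_ind = num_operands++;
    permutes.emplace_back(ind, new_ind);
    return new_ind;
  }
};

// Case 1: an output that is also a model input is redirected to a copy.
void handleInputAsOutput(ToyGraph &g)
{
  for (std::size_t i = 0; i < g.outputs.size(); ++i)
  {
    // Copy the index first so std::replace below does not read a value it overwrites.
    const int ind = g.outputs[i];
    if (std::find(g.inputs.begin(), g.inputs.end(), ind) != g.inputs.end())
    {
      const int new_ind = g.insertPermute(ind);
      // Assumption: rewrite every output slot that holds `ind`.
      std::replace(g.outputs.begin(), g.outputs.end(), ind, new_ind);
    }
  }
}
```

With the example above (Input0 and Output0 both #0), the outputs become {1} and one Permute copies #0 into #1.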

Case 2: Two or more duplicated outputs

Do the same as in Case 1, but between two outputs that share the same tensor index.

e.g.)

((#0 Input0)) -> [#0 Some Operation] -> ((#1 Output0 and also Output1))
becomes
((#0 Input0)) -> [#0 Some Operation] -> ((#1 Output0)) -> [#1 Permute] -> ((#2 Output1))
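Case 2 can be sketched the same way: the first occurrence of each output index is kept, and each later occurrence is rewritten in place to a fresh copy. As before, ToyGraph, handleDuplicatedOutputs, and the integer indices are hypothetical stand-ins for onert's types, and "Some Operation" is not modeled; only the operand indices matter here.

```cpp
#include <cassert>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical toy graph, not onert's ir::Graph: ints stand in for
// ir::OperandIndex, and each Permute op is recorded as a
// (source operand, destination operand) pair.
struct ToyGraph
{
  std::vector<int> inputs;
  std::vector<int> outputs;
  int num_operands = 0;
  std::vector<std::pair<int, int>> permutes;

  // Stand-in for insertPermute(): creates a new operand, records a
  // Permute (copy) from `ind` into it, and returns the new index.
  int insertPermute(int ind)
  {
    assert(ind >= 0 && ind < num_operands);
    const int new_ind = num_operands++;
    permutes.emplace_back(ind, new_ind);
    return new_ind;
  }
};

// Case 2: keep the first occurrence of each output index and redirect
// every later occurrence to a new operand filled by a Permute.
void handleDuplicatedOutputs(ToyGraph &g)
{
  std::unordered_set<int> seen;
  for (int &ind : g.outputs)
  {
    if (seen.insert(ind).second)
      continue; // first occurrence: leave this output slot unchanged
    ind = g.insertPermute(ind); // later occurrence: rewrite only this slot
  }
}
```

For the example above, outputs {1, 1} become {1, 2}: Output0 keeps #1, and a Permute copies #1 into a new #2 for Output1.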

Definition at line 70 of file OddOutputPass.h.

Member Function Documentation

◆ id()

std::string onert::compiler::pass::OddOutputPass::id ( )
inline final virtual

Implements onert::compiler::pass::Pass.

Definition at line 76 of file OddOutputPass.h.

{ return "OddOutputPass"; }

◆ Pass()

onert::compiler::pass::Pass::Pass ( ir::Graph & graph )
inline

Definition at line 42 of file Pass.h.
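The member list shows Pass (ir::Graph &graph) among OddOutputPass's own public members, which suggests the constructor is inherited from Pass (presumably via a using-declaration) rather than redefined. The simplified mock below shows that arrangement; the sketch namespace, the empty Graph placeholder, and the choice to declare id() and run() as pure virtual in IPass are assumptions for illustration, not onert's actual headers.

```cpp
#include <cassert>
#include <string>

namespace sketch
{
// Hypothetical, empty placeholder for ir::Graph.
struct Graph
{
};

// Assumed interface: id() and run() are pure virtual in the base.
class IPass
{
public:
  virtual ~IPass() = default;
  virtual std::string id() = 0;
  virtual void run() = 0;
};

class Pass : public IPass
{
public:
  Pass(Graph &graph) : _graph(graph) {}
  virtual ~Pass() = default;

protected:
  Graph &_graph; // the graph the pass rewrites in place
};

class OddOutputPass : public Pass
{
public:
  using Pass::Pass; // inherit Pass(Graph &) instead of redefining it
  std::string id() final { return "OddOutputPass"; }
  void run() override { /* Case 1 and Case 2 rewrites would go here */ }
};
} // namespace sketch
```

Calling id() through an IPass reference returns "OddOutputPass", matching the inline definition of id() on this page.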

◆ run()

void onert::compiler::pass::OddOutputPass::run ( )
override virtual

Implements onert::compiler::pass::Pass.

Definition at line 30 of file OddOutputPass.cc.

{
  auto &outputs = _graph.getOutputs();

  VERBOSE(OddOutputPass) << "Case 1 : An operand which is a model output and a model input"
                         << std::endl;
  for (const auto &ind : outputs)
  {
    if (_graph.getInputs().contains(ind))
    {
      auto permute_output_ind = insertPermute(ind);
      // Update the output to be newly added operand
      _graph.getOutputs().replace(ind, permute_output_ind);
    }
  }

  VERBOSE(OddOutputPass) << "Case 2 : Two or more duplicated outputs" << std::endl;
  std::unordered_set<ir::OperandIndex> occurence;
  for (auto &&ind : outputs)
  {
    if (occurence.count(ind) == 0)
    {
      occurence.insert(ind);
      continue;
    }

    // Panic when it is const, it must have been handled earlier in another pass
    [[maybe_unused]] auto &obj = _graph.operands().at(ind);
    assert(!obj.isConstant());

    auto permute_output_ind = insertPermute(ind);
    ind = permute_output_ind; // Replace output index to fix output duplication
  }
}
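Putting both cases together, a toy model of run() might look like the sketch below. It keeps the listing's order (Case 1, then Case 2) but swaps onert's types for a hypothetical ToyGraph with integer indices; its insertPermute is a stand-in whose behavior is assumed (the real body is not shown on this page), and the VERBOSE logging and the constant-operand assertion are omitted. The usage scenario is an operand that is a model input and also appears twice as an output, so both loops fire.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical toy graph, not onert's ir::Graph: ints stand in for
// ir::OperandIndex; each Permute is a (source, destination) pair.
struct ToyGraph
{
  std::vector<int> inputs;
  std::vector<int> outputs;
  int num_operands = 0;
  std::vector<std::pair<int, int>> permutes;

  // Stand-in for insertPermute(): new operand plus a recorded copy into it.
  int insertPermute(int ind)
  {
    assert(ind >= 0 && ind < num_operands);
    const int new_ind = num_operands++;
    permutes.emplace_back(ind, new_ind);
    return new_ind;
  }
};

// Toy model of run(): Case 1 first, then Case 2, as in the listing.
void runOddOutputSketch(ToyGraph &g)
{
  // Case 1: outputs that are also model inputs get a copy.
  for (std::size_t i = 0; i < g.outputs.size(); ++i)
  {
    const int ind = g.outputs[i]; // copied before the outputs are rewritten
    if (std::find(g.inputs.begin(), g.inputs.end(), ind) != g.inputs.end())
    {
      const int new_ind = g.insertPermute(ind);
      std::replace(g.outputs.begin(), g.outputs.end(), ind, new_ind);
    }
  }

  // Case 2: every repeated output after the first gets its own copy.
  std::unordered_set<int> seen;
  for (int &ind : g.outputs)
  {
    if (seen.insert(ind).second)
      continue;
    ind = g.insertPermute(ind);
  }
}
```

In this sketch, inputs {0} and outputs {0, 0} end as outputs {1, 2}, with one Permute copying #0 into #1 and a second copying #1 into #2. Because Case 1 runs first, its new operand can itself end up duplicated, and the Case 2 loop then resolves that.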

References onert::compiler::pass::Pass::_graph, onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::contains(), onert::ir::Graph::getInputs(), onert::ir::Graph::getOutputs(), onert::ir::Graph::operands(), onert::ir::OperandIndexSequence::replace(), and VERBOSE.

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: