ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnc::ConstantFoldTranspose Class Reference

#include <ConstantFoldTranspose.h>

Collaboration diagram for nnc::ConstantFoldTranspose:

Public Member Functions

PassData run (PassData data) override
 Run the compiler pass.
 
std::string getName () override
 
- Public Member Functions inherited from nnc::Pass
virtual void cleanup ()
 clean compiler pass data
 
virtual ~Pass ()=default
 

Detailed Description

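Compiler pass that folds Transpose operations applied to constants. Every Transpose operation whose input comes directly from a Constant operation is replaced by a new Constant holding the already-transposed tensor. The new Constant keeps the original constant's quantization parameters, if any. The original Constant is removed once it has no remaining users. Matching is repeated until no Constant→Transpose edge remains in the graph.

Definition at line 25 of file ConstantFoldTranspose.h.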

Member Function Documentation

◆ getName()

std::string nnc::ConstantFoldTranspose::getName ( )
inline override virtual

Reimplemented from nnc::Pass.

Definition at line 30 of file ConstantFoldTranspose.h.

31 {
32 static const std::string name("opt_constant_fold_transpose");
33 return name;
34 };

◆ run()

PassData ConstantFoldTranspose::run ( PassData  data)
override virtual

Run the compiler pass.

Parameters
data - input data the pass operates on (interpreted as a Graph*)
Returns
data that can be passed to the next pass (the same graph, transformed in place)
Exceptions
PassException - object thrown if errors occurred

Implements nnc::Pass.

Definition at line 52 of file ConstantFoldTranspose.cpp.

53{
54 auto graph = static_cast<Graph *>(data);
55
56 GraphPatternMatcher matcher(graph);
57 auto is_constant = [](const Operation *op) { return op->getType() == Operation::Type::constant; };
58 auto is_transpose = [](const Operation *op) {
59 return op->getType() == Operation::Type::transpose;
60 };
61
62 auto matches = matcher.matchEdge(is_constant, is_transpose);
63 while (!matches.empty())
64 {
65 for (const auto &match : matches)
66 {
67 auto constant_op = dynamic_cast<ops::ConstantOp *>(match.first);
68 auto transpose_op = dynamic_cast<ops::TransposeOp *>(match.second);
69
70 const auto elem_type = constant_op->getValue().getElementType();
71 const auto &out_shape = transpose_op->getOutputShape(0);
72 TensorType res_type(elem_type, out_shape);
73 if (constant_op->getOutput(0)->getType().isQuantized())
74 res_type.setQuantization(constant_op->getOutput(0)->getType().getQuantization());
75
76 TensorVariant res(res_type);
77 transpose(constant_op->getValue(), res, transpose_op->getAxisOrder());
78
79 auto new_op = graph->create<ops::ConstantOp>(res);
80
81 graph->replaceNode(transpose_op, new_op);
82 opt_util::removeNodeIfUnused(graph, constant_op);
83 }
84 matches = matcher.matchEdge(is_constant, is_transpose);
85 }
86 return graph;
87}
const Shape & getOutputShape(std::size_t index) const
Definition Operation.h:163
Tensor transpose operation.
Definition TransposeOp.h:34
void removeNodeIfUnused(mir::Graph *g, mir::Operation *op)

References mir::Operation::getOutputShape(), mir::GraphPatternMatcher::matchEdge(), nnc::opt_util::removeNodeIfUnused(), and mir::TensorType::setQuantization().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files:
ConstantFoldTranspose.h
ConstantFoldTranspose.cpp