ONE - On-device Neural Engine
Loading...
Searching...
No Matches
exo::MatMulConverter Class Reference

Convert loco::MatMul to locoex::TFLFullyConnected. More...

#include <MatMulConverter.h>

Collaboration diagram for exo::MatMulConverter:

Public Member Functions

const char * name (void) const final
 
bool convert (loco::MatMul *origin) final
 Converts loco::MatMul to locoex::TFLFullyConnected.
 
- Public Member Functions inherited from exo::CanonicalNodeConverter< loco::MatMul >
bool run (loco::Graph *graph)
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Additional Inherited Members

Detailed Description

Convert loco::MatMul to locoex::TFLFullyConnected.

Definition at line 30 of file MatMulConverter.h.

Member Function Documentation

◆ convert()

bool exo::MatMulConverter::convert (loco::MatMul *origin)
final, virtual

Converts loco::MatMul to locoex::TFLFullyConnected.

Note
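The operands of loco::MatMul live in loco::Domain::Matrix, while TFLFullyConnected consumes plain tensors. For that reason, a loco::MatrixDecode is inserted in front of both the input and the weights, and a loco::MatrixEncode is inserted after the result, so that the domain invariant is preserved.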

How it works:

Before:

    Foo1 --- MatrixEncode --- MatMul --- MatrixDecode --- Bar
    Foo2 --- MatrixEncode ---/

After:

    Foo1 - MatrixEncode - MatrixDecode - TFLFullyConnected - MatrixEncode - MatrixDecode - Bar
    Foo2 - MatrixEncode - MatrixDecode -/

Note
This method replaces MatMul with the following subgraph:

    - MatrixDecode - TFLFullyConnected - MatrixEncode -
    - MatrixDecode -/

The resulting redundant MatrixEncode/MatrixDecode pairs are removed later by transformation passes.

Reference (TensorFlow Lite v1.13.1 fully-connected reference kernel): https://github.com/tensorflow/tensorflow/blob/v1.13.1/tensorflow/lite/kernels/internal/reference/fully_connected.h

Implements exo::CanonicalNodeConverter< loco::MatMul >.

Definition at line 54 of file MatMulConverter.cpp.

55{
56 auto *graph = origin->graph();
57
58 assert(origin->lhs());
59 assert(origin->rhs());
60
61 auto tfl_fc = graph->nodes()->create<locoex::TFLFullyConnected>();
63
64 // let's create a new graph connection with tfl_fc
65 {
66 // input
67 auto lhs_matrix_dec = make_matrix_decode<MatrixLayout::HW>(origin->lhs());
68 tfl_fc->input(lhs_matrix_dec);
69
70 // weights (WH format on TFLite)
71 auto rhs_matrix_dec = make_matrix_decode<MatrixLayout::WH>(origin->rhs());
72 tfl_fc->weights(rhs_matrix_dec);
73
74 // bias
75 auto zero_const = graph->nodes()->create<locoex::TFLConst>();
76 { // TODO Create optimization pass which fuse additional Add into bias of Conv or FC
77 assert(loco::shape_known(origin));
78 assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
79
80 auto output_depth = loco::shape_get(origin->rhs()).as<loco::MatrixShape>().width();
81 // TODO Fix it with type inference
82 zero_const->dtype(loco::DataType::FLOAT32);
83 zero_const->rank(1);
84 zero_const->dim(0) = output_depth;
85 zero_const->size<loco::DataType::FLOAT32>(output_depth.value());
86 for (uint32_t x = 0; x < output_depth.value(); x++)
87 zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
88 }
89 tfl_fc->bias(zero_const);
90
91 // output
92 auto matrix_enc = make_matrix_encode<MatrixLayout::HW>(tfl_fc);
93
94 // replace canonical node
95 loco::replace(origin).with(matrix_enc);
96 origin->lhs(nullptr);
97 origin->rhs(nullptr);
98 }
99
100 return true;
101}
Node * rhs(void) const
Definition Nodes.h:1073
Node * lhs(void) const
Definition Nodes.h:1070
Matrix Shape.
Definition MatrixShape.h:38
Graph * graph(void)
Definition Node.h:70
ShapeType as(void) const
void with(Node *into) const
Definition Node.cpp:66
Class to build tensor data.
Definition TFLNodes.h:198
FULLY_CONNECTED in TensorFlow Lite.
Definition TFLNodes.h:298
template loco::MatrixDecode * make_matrix_decode< MatrixLayout::WH >(loco::Node *input_for_decode)
template loco::MatrixDecode * make_matrix_decode< MatrixLayout::HW >(loco::Node *input_for_decode)
template loco::MatrixEncode * make_matrix_encode< MatrixLayout::HW >(loco::Node *input_for_encode)
bool shape_known(const Node *node)
bool dtype_known(const Node *node)
NodeShape shape_get(const Node *node)
DataType dtype_get(const Node *node)
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82

References loco::NodeShape::as(), loco::dtype_get(), loco::dtype_known(), locoex::TFLNodeMixin< TFLNodeTrait::FusedActFunc >::fusedActivationFunction(), loco::Node::graph(), loco::MatMul::lhs(), exo::make_matrix_decode< MatrixLayout::HW >(), exo::make_matrix_decode< MatrixLayout::WH >(), exo::make_matrix_encode< MatrixLayout::HW >(), locoex::NONE, loco::replace(), loco::MatMul::rhs(), loco::shape_get(), loco::shape_known(), and loco::Subst< SubstQualifier::Default >::with().

◆ name()

const char * exo::MatMulConverter::name ( void  ) const
inline, final, virtual

Reimplemented from exo::CanonicalNodeConverter< loco::MatMul >.

Definition at line 33 of file MatMulConverter.h.

33{ return "exo::MatMulConverter"; }

The documentation for this class was generated from the following files: