ONE - On-device Neural Engine
Loading...
Searching...
No Matches
exo::TransposedConv2DConverter Class Reference

Convert loco::TransposedConv2D to locoex::TFLTransposeConv and auxiliary nodes. More...

#include <TransposedConv2DConverter.h>

Collaboration diagram for exo::TransposedConv2DConverter:

Public Member Functions

const char * name (void) const final
 
bool convert (loco::TransposedConv2D *origin) final
 
- Public Member Functions inherited from exo::CanonicalNodeConverter< loco::TransposedConv2D >
bool run (loco::Graph *graph)
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Additional Inherited Members

Detailed Description

Convert loco::TransposedConv2D to locoex::TFLTransposeConv and auxiliary nodes.

<BEFORE>

    IFM ------- TransposedConv2D --- OFM
    (Feature)   /                   (Feature)
               /
    KER -------
    (Filter)

<AFTER>

  out_backprop : IFM ------- FeatureDecode --- TFLTransposeConv --- FeatureEncode --- OFM
                 [Feature]   [Tensor]          /  /                 [Tensor]          [Feature]
                                              /  /
  filter:        KER ------- FilterDecode ---   /
                 [Filter]    [Tensor]          /
                                              /
  input_sizes  : TFLConst (new) -------------
                 [Tensor]

Definition at line 51 of file TransposedConv2DConverter.h.

Member Function Documentation

◆ convert()

bool exo::TransposedConv2DConverter::convert ( loco::TransposedConv2D * origin)
[final], [virtual]

Implements exo::CanonicalNodeConverter< loco::TransposedConv2D >.

Definition at line 29 of file TransposedConv2DConverter.cpp.

// convert(): lowers `origin` to a TFLTransposeConv fed by FilterDecode (OHWI) and FeatureDecode (NHWC),
// followed by FeatureEncode (NHWC). Returns false before creating any node when it cannot convert.
30{
31 // Output shape of origin is required to fill tfl_tr_conv->inputSizes() below
32 if (not loco::shape_known(origin))
33 return false; // nothing created yet; graph unchanged
34
35 if ((origin->ifm() == nullptr) or (origin->ker() == nullptr))
36 return false; // both inputs must be connected; graph unchanged
37
38 auto *graph = origin->graph();
39
40 auto tfl_tr_conv = graph->nodes()->create<locoex::TFLTransposeConv>();
41 { // copy stride and padding attributes from origin
42 tfl_tr_conv->stride()->w(origin->stride()->horizontal()); // horizontal -> w
43 tfl_tr_conv->stride()->h(origin->stride()->vertical()); // vertical -> h
44
45 auto pad = origin->pad();
46 if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
47 tfl_tr_conv->padding(locoex::Padding::VALID); // all four pads are zero
48 else
49 // TODO This is necessary, but not sufficient condition. More rigorous check required
50 tfl_tr_conv->padding(locoex::Padding::SAME); // any non-zero pad is treated as SAME
51 }
52
53 // let's create a new graph connection with tfl_tr_conv
54 {
55 // Make inputSizes from shape of origin: a new rank-1 S32 TFLConst holding 4 values
56 auto input_sizes_const = graph->nodes()->create<locoex::TFLConst>();
57 auto origin_shape = loco::shape_get(origin).as<loco::FeatureShape>(); // NOTE(review): .value() below assumes every dimension is known; confirm shape_known() guarantees that
58
59 const loco::DataType S32 = loco::DataType::S32;
60
61 input_sizes_const->dtype(S32);
62 input_sizes_const->rank(1);
63 input_sizes_const->dim(0) = 4;
64 input_sizes_const->size<S32>(4); // allocate storage for the 4 elements written below
65 // Note that NHWC is layout for inputSizes determined by tflite format
66 input_sizes_const->at<S32>(0) = origin_shape.count().value(); // N
67 input_sizes_const->at<S32>(1) = origin_shape.height().value(); // H
68 input_sizes_const->at<S32>(2) = origin_shape.width().value(); // W
69 input_sizes_const->at<S32>(3) = origin_shape.depth().value(); // C
70
71 tfl_tr_conv->inputSizes(input_sizes_const);
72
73 // filter
74 auto filter_dec = make_filter_decode<FilterLayout::OHWI>(origin->ker());
75 tfl_tr_conv->filter(filter_dec);
76
77 // outBackprop
78 auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
79 tfl_tr_conv->outBackprop(feature_dec);
80
81 // output
82 auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_tr_conv);
83
84 // replace canonical node
85 loco::replace(origin).with(feature_enc); // redirect origin's users to feature_enc
86 origin->ifm(nullptr); // detach origin from ifm; NOTE(review): ker stays connected to origin - confirm intended
87 }
88
89 return true; // conversion performed; graph modified
90}
Feature Map Shape.
Graph * graph(void)
Definition Node.h:70
ShapeType as(void) const
uint32_t horizontal(void) const
Definition Stride.h:40
uint32_t vertical(void) const
Definition Stride.h:36
void with(Node *into) const
Definition Node.cpp:66
Node * ifm(void) const
Definition Nodes.h:690
Node * ker(void) const
Definition Nodes.h:693
const Stride< 2 > * stride(void) const
Definition Nodes.h:701
const Padding2D * pad(void) const
Definition Nodes.h:697
int32_t w() const
Definition TFLNodes.h:65
Class to build tensor data.
Definition TFLNodes.h:198
const loco::DataTypeImpl< DT >::Type & at(uint32_t n) const
Definition TFLNodes.cpp:42
TRANSPOSE_CONV in TensorFlow Lite.
Definition TFLNodes.h:528
const Stride * stride(void) const
Definition TFLNodes.h:543
template loco::FeatureDecode * make_feature_decode< FeatureLayout::NHWC >(loco::Node *input_for_encode)
template loco::FilterDecode * make_filter_decode< FilterLayout::OHWI >(loco::Node *input_for_decode)
template loco::FeatureEncode * make_feature_encode< FeatureLayout::NHWC >(loco::Node *input_for_encode)
bool shape_known(const Node *node)
DataType
"scalar" value type
Definition DataType.h:27
NodeShape shape_get(const Node *node)
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
const loco::DataType S32

References loco::NodeShape::as(), locoex::TFLConst::at(), loco::Node::graph(), loco::Stride< 2 >::horizontal(), loco::TransposedConv2D::ifm(), loco::TransposedConv2D::ker(), exo::make_feature_decode< FeatureLayout::NHWC >(), exo::make_feature_encode< FeatureLayout::NHWC >(), exo::make_filter_decode< FilterLayout::OHWI >(), loco::TransposedConv2D::pad(), loco::replace(), locoex::SAME, loco::shape_get(), loco::shape_known(), locoex::TFLTransposeConv::stride(), loco::TransposedConv2D::stride(), locoex::VALID, loco::Stride< 2 >::vertical(), locoex::Stride::w(), and loco::Subst< SubstQualifier::Default >::with().

◆ name()

const char * exo::TransposedConv2DConverter::name ( void  ) const
[inline], [final], [virtual]

Reimplemented from exo::CanonicalNodeConverter< loco::TransposedConv2D >.

Definition at line 54 of file TransposedConv2DConverter.h.

// name(): returns the fixed identifier string of this converter.
54{ return "exo::TransposedConv2DConverter"; }

The documentation for this class was generated from the following files: