ONE - On-device Neural Engine
Loading...
Searching...
No Matches
exo::Conv2DConverter Class Reference

Convert loco::Conv2D to locoex::TFLConv2D. More...

#include <Conv2DConverter.h>

Collaboration diagram for exo::Conv2DConverter:

Public Member Functions

const char * name (void) const final
 
bool convert (loco::Conv2D *origin) final
 Converts loco::Conv2D to locoex::TFLConv2D.
 
- Public Member Functions inherited from exo::CanonicalNodeConverter< loco::Conv2D >
bool run (loco::Graph *graph)
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Additional Inherited Members

Detailed Description

Convert loco::Conv2D to locoex::TFLConv2D.

Definition at line 30 of file Conv2DConverter.h.

Member Function Documentation

◆ convert()

bool exo::Conv2DConverter::convert ( loco::Conv2D * origin )
final virtual

Converts loco::Conv2D to locoex::TFLConv2D.

Note
Because TFLConv2D accepts an input and a filter in loco::Domain::Tensor, loco::FeatureDecode and loco::FilterDecode are inserted as its inputs to satisfy the domain invariant. Please refer to the comment in AvgPool2DConverter.

Implements exo::CanonicalNodeConverter< loco::Conv2D >.

Definition at line 37 of file Conv2DConverter.cpp.

38{
39 auto *graph = origin->graph();
40
41 assert(origin->ifm());
42 assert(origin->ker());
43
44 auto tfl_conv2d = graph->nodes()->create<locoex::TFLConv2D>();
45 {
46 tfl_conv2d->stride()->w(origin->stride()->horizontal());
47 tfl_conv2d->stride()->h(origin->stride()->vertical());
48
49 auto pad = origin->pad();
50 if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
51 tfl_conv2d->padding(locoex::Padding::VALID);
52 else
53 // TODO This is necessary, but not sufficient condition. More rigorous check required
54 tfl_conv2d->padding(locoex::Padding::SAME);
55
56 tfl_conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
57 }
58
59 // let's create a new graph connection with tfl_conv2d
60 {
61 // input
62 auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
63 tfl_conv2d->input(feature_dec);
64
65 // filter
66 auto filter_dec = make_filter_decode<FilterLayout::OHWI>(origin->ker());
67 tfl_conv2d->filter(filter_dec);
68
69 // bias
70 auto zero_const = graph->nodes()->create<locoex::TFLConst>();
71 {
72 assert(loco::shape_known(origin));
73 assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
74
75 auto output_depth = loco::shape_get(origin->ker()).as<loco::FilterShape>().count();
76
77 zero_const->dtype(loco::DataType::FLOAT32);
78 zero_const->rank(1);
79 zero_const->dim(0) = output_depth;
80 zero_const->size<loco::DataType::FLOAT32>(output_depth.value());
81 for (uint32_t x = 0; x < output_depth.value(); x++)
82 zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
83 }
84 tfl_conv2d->bias(zero_const);
85
86 // output
87 auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_conv2d);
88
89 // replace canonical node
90 loco::replace(origin).with(feature_enc);
91 origin->ifm(nullptr);
92 }
93
94 return true;
95}
const Stride< 2 > * stride(void) const
Definition Nodes.h:567
Node * ker(void) const
Definition Nodes.h:559
const Padding2D * pad(void) const
Definition Nodes.h:563
Node * ifm(void) const
Definition Nodes.h:556
Filter Shape.
Definition FilterShape.h:43
Graph * graph(void)
Definition Node.h:70
ShapeType as(void) const
uint32_t horizontal(void) const
Definition Stride.h:40
uint32_t vertical(void) const
Definition Stride.h:36
void with(Node *into) const
Definition Node.cpp:66
int32_t w() const
Definition TFLNodes.h:65
Class to build tensor data.
Definition TFLNodes.h:198
CONV_2D in TensorFlow Lite.
Definition TFLNodes.h:218
const Stride * stride(void) const
Definition TFLNodes.h:233
template loco::FeatureDecode * make_feature_decode< FeatureLayout::NHWC >(loco::Node *input_for_encode)
template loco::FilterDecode * make_filter_decode< FilterLayout::OHWI >(loco::Node *input_for_decode)
template loco::FeatureEncode * make_feature_encode< FeatureLayout::NHWC >(loco::Node *input_for_encode)
bool shape_known(const Node *node)
bool dtype_known(const Node *node)
NodeShape shape_get(const Node *node)
DataType dtype_get(const Node *node)
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82

References loco::NodeShape::as(), loco::dtype_get(), loco::dtype_known(), loco::Node::graph(), loco::Stride< 2 >::horizontal(), loco::Conv2D::ifm(), loco::Conv2D::ker(), exo::make_feature_decode< FeatureLayout::NHWC >(), exo::make_feature_encode< FeatureLayout::NHWC >(), exo::make_filter_decode< FilterLayout::OHWI >(), locoex::NONE, loco::Conv2D::pad(), loco::replace(), locoex::SAME, loco::shape_get(), loco::shape_known(), locoex::TFLConv2D::stride(), loco::Conv2D::stride(), locoex::VALID, loco::Stride< 2 >::vertical(), locoex::Stride::w(), and loco::Subst< SubstQualifier::Default >::with().

◆ name()

const char * exo::Conv2DConverter::name ( void  ) const
inline final virtual

Reimplemented from exo::CanonicalNodeConverter< loco::Conv2D >.

Definition at line 33 of file Conv2DConverter.h.

33{ return "exo::Conv2DConverter"; }

The documentation for this class was generated from the following files: