ONE - On-device Neural Engine
Loading...
Searching...
No Matches
exo::DepthwiseConv2DConverter Class Reference

Convert loco::DepthwiseConv2D to locoex::TFLDepthwiseConv2D and auxiliary. More...

#include <DepthwiseConv2DConverter.h>

Collaboration diagram for exo::DepthwiseConv2DConverter:

Public Member Functions

const char * name (void) const final
 
bool convert (loco::DepthwiseConv2D *origin) final
 
- Public Member Functions inherited from exo::CanonicalNodeConverter< loco::DepthwiseConv2D >
bool run (loco::Graph *graph)
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Additional Inherited Members

Detailed Description

Convert loco::DepthwiseConv2D to locoex::TFLDepthwiseConv2D and auxiliary.

<BEFORE>

IFM ------- DepthwiseConv2D --- Out
[Feature]  /                    [Feature]
          /
KER ------
[DWFilter]

<AFTER>

             TFLConst (bias) ----------------------------------
                                                               \
IFM ------- FeatureDecode ------------------------------------ TFLDepthwiseConv2D --- FeatureEncode --- Out
[Feature]                  [Tensor]                           /                    [Tensor]          [Feature]
                                                             /
KER ------- DepthwiseFilterDecode --- TFLReshape ------------
[DWFilter]     [Tensor / H W C M]     [Tensor / 1 H W CM]

Definition at line 50 of file DepthwiseConv2DConverter.h.

Member Function Documentation

◆ convert()

bool exo::DepthwiseConv2DConverter::convert ( loco::DepthwiseConv2D * origin)
final virtual

Implements exo::CanonicalNodeConverter< loco::DepthwiseConv2D >.

Definition at line 33 of file DepthwiseConv2DConverter.cpp.

34{
35 // Filter shape is required
36 if (not loco::shape_known(origin->ker()))
37 return false;
38
39 auto filter_shape = loco::shape_get(origin->ker()).as<loco::DepthwiseFilterShape>();
40
41 if ((origin->ifm() == nullptr) or (origin->ker() == nullptr))
42 return false;
43
44 auto *graph = origin->graph();
45
46 auto tfl_dw_conv2d = graph->nodes()->create<locoex::TFLDepthwiseConv2D>();
47 {
48 tfl_dw_conv2d->stride()->w(origin->stride()->horizontal());
49 tfl_dw_conv2d->stride()->h(origin->stride()->vertical());
50
51 auto pad = origin->pad();
52 if (pad->bottom() == 0 && pad->top() == 0 && pad->left() == 0 && pad->right() == 0)
53 tfl_dw_conv2d->padding(locoex::Padding::VALID);
54 else
55 // TODO This is necessary, but not sufficient condition. More rigorous check required
56 tfl_dw_conv2d->padding(locoex::Padding::SAME);
57
58 tfl_dw_conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
59
60 uint32_t multiplier = filter_shape.multiplier().value();
61 EXO_ASSERT(multiplier < static_cast<uint32_t>(std::numeric_limits<int32_t>::max()),
62 "Multiplier is too big that casting may occur unintended behavior")
63
64 tfl_dw_conv2d->depthMultiplier(static_cast<int32_t>(multiplier));
65 }
66
67 // let's create a new graph connection with tfl_dw_conv2d
68 {
69 // ifm --- feature_dec --- tfl_dw_conv2d
70 auto feature_dec = make_feature_decode<FeatureLayout::NHWC>(origin->ifm());
71 tfl_dw_conv2d->input(feature_dec);
72
73 // ker --- filter_dec(H x W x C x M) --- reshape(1 x H x W x CM) --- tfl_dw_conv2d
75
76 auto reshape = graph->nodes()->create<locoex::TFLReshape>();
77 reshape->tensor(filter_dec);
78
79 int32_t new_shape[4] = {
80 1, static_cast<int32_t>(filter_shape.height().value()),
81 static_cast<int32_t>(filter_shape.width().value()),
82 static_cast<int32_t>(filter_shape.depth().value() * filter_shape.multiplier().value())};
83 locoex::set_new_shape(reshape, new_shape, 4);
84
85 tfl_dw_conv2d->filter(reshape);
86
87 // bias
88 auto zero_const = graph->nodes()->create<locoex::TFLConst>();
89 {
90 assert(loco::shape_known(origin));
91 assert(loco::dtype_known(origin) && loco::dtype_get(origin) == loco::DataType::FLOAT32);
92
93 // bias size is C * M
94 uint32_t bias_size = filter_shape.depth().value() * filter_shape.multiplier().value();
95
96 zero_const->dtype(loco::DataType::FLOAT32);
97 zero_const->rank(1);
98 zero_const->dim(0) = bias_size;
99 zero_const->size<loco::DataType::FLOAT32>(bias_size);
100 for (uint32_t x = 0; x < bias_size; x++)
101 zero_const->at<loco::DataType::FLOAT32>(x) = 0.0;
102 }
103 tfl_dw_conv2d->bias(zero_const);
104
105 // output
106 auto feature_enc = make_feature_encode<FeatureLayout::NHWC>(tfl_dw_conv2d);
107
108 // replace canonical node
109 loco::replace(origin).with(feature_enc);
110 origin->ifm(nullptr);
111 }
112
113 return true;
114}
Node * ifm(void) const
Definition Nodes.h:584
const Stride< 2 > * stride(void) const
Definition Nodes.h:595
const Padding2D * pad(void) const
Definition Nodes.h:591
Node * ker(void) const
Definition Nodes.h:587
Graph * graph(void)
Definition Node.h:70
ShapeType as(void) const
uint32_t horizontal(void) const
Definition Stride.h:40
uint32_t vertical(void) const
Definition Stride.h:36
void with(Node *into) const
Definition Node.cpp:66
int32_t w() const
Definition TFLNodes.h:65
Class to build tensor data.
Definition TFLNodes.h:198
DEPTHWISE_CONV_2D in TensorFlow Lite.
Definition TFLNodes.h:248
const Stride * stride(void) const
Definition TFLNodes.h:263
#define EXO_ASSERT(condition, msg)
Definition Check.h:28
template loco::DepthwiseFilterDecode * make_dw_filter_decode< DepthwiseFilterLayout::HWCM >(loco::Node *input_for_decode)
template loco::FeatureDecode * make_feature_decode< FeatureLayout::NHWC >(loco::Node *input_for_encode)
template loco::FeatureEncode * make_feature_encode< FeatureLayout::NHWC >(loco::Node *input_for_encode)
bool shape_known(const Node *node)
bool dtype_known(const Node *node)
NodeShape shape_get(const Node *node)
DataType dtype_get(const Node *node)
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size)
Set both TFLReshape's 2nd input as TFLConst, and newShape attribute with same value.
Definition TFLNodes.cpp:67

References loco::NodeShape::as(), loco::dtype_get(), loco::dtype_known(), EXO_ASSERT, loco::Node::graph(), loco::Stride< 2 >::horizontal(), loco::DepthwiseConv2D::ifm(), loco::DepthwiseConv2D::ker(), exo::make_dw_filter_decode< DepthwiseFilterLayout::HWCM >(), exo::make_feature_decode< FeatureLayout::NHWC >(), exo::make_feature_encode< FeatureLayout::NHWC >(), locoex::NONE, loco::DepthwiseConv2D::pad(), loco::replace(), locoex::SAME, locoex::set_new_shape(), loco::shape_get(), loco::shape_known(), locoex::TFLDepthwiseConv2D::stride(), loco::DepthwiseConv2D::stride(), locoex::VALID, loco::Stride< 2 >::vertical(), locoex::Stride::w(), and loco::Subst< SubstQualifier::Default >::with().

◆ name()

const char * exo::DepthwiseConv2DConverter::name ( void  ) const
inline final virtual

Reimplemented from exo::CanonicalNodeConverter< loco::DepthwiseConv2D >.

Definition at line 53 of file DepthwiseConv2DConverter.h.

// Returns the fixed string literal "exo::DepthwiseConv2DConverter" as this converter's name.
53{ return "exo::DepthwiseConv2DConverter"; }

The documentation for this class was generated from the following files: