ONE - On-device Neural Engine
Loading...
Searching...
No Matches
exo::DomainConverter< CanonicalT, TFLT > Class Template Reference

Class that handles domain conversion when converting a canonical node into TFL node(s). More...

#include <GraphBlock.h>

Public Member Functions

template<FeatureLayout FeatureLayoutT>
TFLT * convert (CanonicalT *origin, InputHandler< CanonicalT, TFLT > &input_handler)
 Performs domain conversion.
 

Detailed Description

template<class CanonicalT, class TFLT>
class exo::DomainConverter< CanonicalT, TFLT >

Class that handles domain conversion when converting a canonical node into TFL node(s).

Definition at line 134 of file GraphBlock.h.

Member Function Documentation

◆ convert()

template<class CanonicalT , class TFLT >
template<FeatureLayout FeatureLayoutT>
TFLT * exo::DomainConverter< CanonicalT, TFLT >::convert ( CanonicalT *  origin,
InputHandler< CanonicalT, TFLT > &  input_handler 
)

Performs domain conversion.

  1. If origin belongs to loco::Domain::Tensor, replace origin with a TFL node.
  2. If origin belongs to loco::Domain::Feature, insert a loco::FeatureDecode for each input and a loco::FeatureEncode for the output, then replace origin with a TFL node.
Returns
The new TFL node, or nullptr if the shape of origin is not known.

Definition at line 152 of file GraphBlock.h.

154{
155 static_assert(FeatureLayoutT == FeatureLayout::NHWC, "Feature layout should be NHWC");
156
157 if (!loco::shape_known(origin))
158 {
159 return nullptr;
160 }
161
162 auto tfl_node = origin->graph()->nodes()->template create<TFLT>();
163
164 // when the input is Tensor, just replace canonical node to TFL node.
165 if (loco::shape_get(origin).domain() == loco::Domain::Tensor)
166 {
167 input_handler.handover(origin, tfl_node);
168
169 loco::replace(origin).with(tfl_node);
170 input_handler.nullify(origin);
171
172 return tfl_node;
173 }
174 else if (loco::shape_get(origin).domain() == loco::Domain::Feature)
175 {
176 std::vector<loco::Node *> feature_decodes;
177
178 for (auto input : input_handler.getInputsToConvert(origin))
179 {
180 auto dec = make_feature_decode<FeatureLayoutT>(input);
181 feature_decodes.emplace_back(dec);
182 }
183
184 input_handler.set(tfl_node, feature_decodes);
185
186 auto enc = make_feature_encode<FeatureLayoutT>(tfl_node);
187
188 loco::replace(origin).with(enc);
189 input_handler.nullify(origin);
190
191 return tfl_node;
192 }
193 else
194 INTERNAL_EXN_V("Unsupported loco::Domain", oops::to_uint32(loco::shape_get(origin).domain()));
195}
#define INTERNAL_EXN_V(msg, val)
Throws an internal exception with a message and a value.
Definition InternalExn.h:28
void with(Node *into) const
Definition Node.cpp:66
bool shape_known(const Node *node)
NodeShape shape_get(const Node *node)
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
uint32_t to_uint32(T a)
Definition InternalExn.h:33

References loco::Feature, exo::InputHandler< CanonicalT, TFLT >::getInputsToConvert(), exo::InputHandler< CanonicalT, TFLT >::handover(), INTERNAL_EXN_V, exo::NHWC, exo::InputHandler< CanonicalT, TFLT >::nullify(), loco::replace(), exo::InputHandler< CanonicalT, TFLT >::set(), loco::shape_get(), loco::shape_known(), loco::Tensor, oops::to_uint32(), and loco::Subst< SubstQualifier::Default >::with().

Referenced by exo::EltwiseMaxConverter::convert(), exo::EltwiseSqrtConverter::convert(), exo::ReluConverter::convert(), and exo::Relu6Converter::convert().


The documentation for this class was generated from the following file: