ONE - On-device Neural Engine
Loading...
Searching...
No Matches
exo::TensorTransposeConverter Class Reference

Convert loco::TensorTranspose to locoex::TFLTranspose. More...

#include <TensorTransposeConverter.h>

Collaboration diagram for exo::TensorTransposeConverter:

Public Member Functions

const char * name (void) const final
 
bool convert (loco::TensorTranspose *origin) final
 Converts loco::TensorTranspose to locoex::TFLTranspose.
 
- Public Member Functions inherited from exo::CanonicalNodeConverter< loco::TensorTranspose >
bool run (loco::Graph *graph)
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Additional Inherited Members

Detailed Description

Convert loco::TensorTranspose to locoex::TFLTranspose.

Definition at line 30 of file TensorTransposeConverter.h.

Member Function Documentation

◆ convert()

bool exo::TensorTransposeConverter::convert ( loco::TensorTranspose * origin )
final, virtual

Converts loco::TensorTranspose to locoex::TFLTranspose.

Implements exo::CanonicalNodeConverter< loco::TensorTranspose >.

Definition at line 58 of file TensorTransposeConverter.cpp.

59{ // Lowers `origin` into a locoex::TFLTranspose whose perm operand is a newly created S32 TFLConst.
60 auto *graph = origin->graph(); // graph that owns `origin`; both new nodes are created in it
61
62 auto tfl_transpose = graph->nodes()->create<locoex::TFLTranspose>(); // NOTE(review): created before validation, so a throw below leaves this node unused in the graph
63 {
64 // validation: input must be set, input rank must equal perm size, then validate_perm() runs its own checks
65 {
66 assert(origin->input() != nullptr);
67
68 auto input_rank = loco::shape_get(origin->input()).as<loco::TensorShape>().rank(); // presumes shape inference already annotated the input with a TensorShape — TODO confirm pass ordering
69 if (input_rank != origin->perm()->size())
70 INTERNAL_EXN_V("perm size should be same with input rank",
71 oops::to_uint32(origin->perm()->size())); // value attached to the exception is the perm size, not the input rank
72
73 validate_perm(origin); // helper not shown on this page; presumably checks the individual axis values — verify
74 }
75
76 tfl_transpose->a(origin->input()); // operand `a`: the tensor being transposed (same input as the canonical node)
77
78 // perm : materialize origin->perm() as a rank-1 S32 TFLConst with one element per axis
79 auto perm_const = graph->nodes()->create<locoex::TFLConst>();
80 {
81 perm_const->dtype(loco::DataType::S32);
82 perm_const->rank(1);
83 perm_const->dim(0) = origin->perm()->size();
84 perm_const->size<loco::DataType::S32>(origin->perm()->size());
85
86 // copy each axis index of the canonical perm into the constant, preserving order
87 for (loco::TensorAxis x = 0; x < origin->perm()->size(); x++)
88 {
89 perm_const->at<loco::DataType::S32>(x) = origin->perm()->axis(x);
90 }
91 }
92 tfl_transpose->perm(perm_const);
93 }
94
95 // replace canonical node: `tfl_transpose` takes over `origin`'s place via loco::replace(...).with(...)
96 loco::replace(origin).with(tfl_transpose);
97 origin->input(nullptr); // detach the replaced node from its input; presumably so it no longer references it — verify
98
99 return true; // returned unconditionally once conversion completes without throwing
100}
#define INTERNAL_EXN_V(msg, val)
@brief Throw an internal exception with a message and a value.
Definition InternalExn.h:28
Graph * graph(void)
Definition Node.h:70
ShapeType as(void) const
void with(Node *into) const
Definition Node.cpp:66
const TensorAxis & axis(TensorAxis n) const
Definition Nodes.h:1107
uint32_t size() const
Definition Nodes.h:1104
Perm * perm(void)
Definition Nodes.h:1114
Node * input(void) const
Definition Nodes.h:1095
Class to build tensor data.
Definition TFLNodes.h:198
TRANSPOSE in TensorFlow Lite.
Definition TFLNodes.h:506
uint32_t TensorAxis
Definition TensorAxis.h:25
NodeShape shape_get(const Node *node)
Subst< SubstQualifier::Default > replace(Node *node)
Definition Node.cpp:82
uint32_t to_uint32(T a)
Definition InternalExn.h:33

References loco::NodeShape::as(), loco::TensorTranspose::Perm::axis(), loco::Node::graph(), loco::TensorTranspose::input(), INTERNAL_EXN_V, loco::TensorTranspose::perm(), loco::replace(), loco::shape_get(), loco::TensorTranspose::Perm::size(), oops::to_uint32(), and loco::Subst< SubstQualifier::Default >::with().

◆ name()

const char * exo::TensorTransposeConverter::name ( void ) const
inline, final, virtual

Reimplemented from exo::CanonicalNodeConverter< loco::TensorTranspose >.

Definition at line 33 of file TensorTransposeConverter.h.

33{ return "exo::TensorTransposeConverter"; } // fixed, fully qualified identifier returned by name()

The documentation for this class was generated from the following files:
- TensorTransposeConverter.h
- TensorTransposeConverter.cpp