ONE - On-device Neural Engine
luci::CopyQuantParamPass Class Reference

Pass that copies the quantparam (scale, zerop) of one tensor to another tensor.

#include <CopyQuantParamPass.h>


Public Types

using TensorVector = std::vector< std::string >
 

Public Member Functions

 CopyQuantParamPass (TensorVector &src_tensors, TensorVector &dst_tensors)
 
virtual const char * name (void) const
 
bool run (loco::Graph *graph)
 Run the pass.
 
Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

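Pass that copies the quantparam (scale, zerop) of one tensor to another tensor. Tensors are identified by name, and the i-th entry of src_tensors is paired with the i-th entry of dst_tensors, so both vectors must have the same length.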

Definition at line 30 of file CopyQuantParamPass.h.
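
To make "copying quantparam" concrete, here is a minimal, self-contained sketch. The types QuantParamSketch and TensorSketch and the function copy_quantparam_sketch are hypothetical stand-ins invented for illustration; they are not luci types, and the real quantparam structure may carry more fields than scale and zerop. The choice to clear the destination when the source has no quantparam is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor's quantization parameters.
struct QuantParamSketch
{
  std::vector<float> scale;
  std::vector<int64_t> zerop;
};

// Hypothetical stand-in for a named tensor that may carry quantparam.
struct TensorSketch
{
  std::string name;
  std::unique_ptr<QuantParamSketch> quantparam; // null if the tensor is not quantized
};

// Copy quantparam of src to dst, replacing whatever dst had before.
// Assumption of this sketch: a source without quantparam clears the destination.
void copy_quantparam_sketch(const TensorSketch &src, TensorSketch &dst)
{
  if (&src == &dst)
    return; // copying a tensor onto itself changes nothing

  if (!src.quantparam)
  {
    dst.quantparam.reset();
    return;
  }

  // Deep copy, so later edits to src do not leak into dst.
  dst.quantparam = std::make_unique<QuantParamSketch>(*src.quantparam);
  assert(dst.quantparam.get() != src.quantparam.get());
}
```

The deep copy is a deliberate choice here: the destination owns its own parameters after the call, so later changes to the source do not affect it.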

Member Typedef Documentation

◆ TensorVector

using luci::CopyQuantParamPass::TensorVector = std::vector<std::string>

Definition at line 33 of file CopyQuantParamPass.h.

Constructor & Destructor Documentation

◆ CopyQuantParamPass()

luci::CopyQuantParamPass::CopyQuantParamPass ( TensorVector & src_tensors,
TensorVector & dst_tensors 
)
inline

Definition at line 36 of file CopyQuantParamPass.h.

  : _src_tensors{src_tensors}, _dst_tensors{dst_tensors}
{
  // DO NOTHING
}

Member Function Documentation

◆ name()

virtual const char * luci::CopyQuantParamPass::name ( void  ) const
inline virtual

Reimplemented from logo::Pass.

Definition at line 41 of file CopyQuantParamPass.h.

{ return "luci::CopyQuantParamPass"; }

Referenced by run().

◆ run()

bool luci::CopyQuantParamPass::run ( loco::Graph * graph)
virtual

Run the pass.

Returns
false if nothing was changed. The implementation always returns false, since the pass is meant to run only once.

Implements logo::Pass.

Definition at line 35 of file CopyQuantParamPass.cpp.

{
  // NOTE: this listing refers to the graph argument as `g`, while the
  // declaration above names it `graph`.
  LOGGER(l);

  INFO(l) << "CopyQuantParamPass Start" << std::endl;

  if (_src_tensors.size() != _dst_tensors.size())
    throw std::runtime_error("The numbers of Source/Destination tensors do not match.");

  // Return src/dst CircleNodes
  auto get_src_dst = [&g](std::string src, std::string dst) {
    SrcDst src_dst;
    for (auto node : loco::active_nodes(loco::output_nodes(g)))
    {
      auto const cnode = loco::must_cast<CircleNode *>(node);
      auto const name = cnode->name();
      if (name == src)
        src_dst.src = cnode;

      if (name == dst)
        src_dst.dst = cnode;
    }
    return src_dst;
  };

  for (uint32_t i = 0; i < _src_tensors.size(); i++)
  {
    auto &src = _src_tensors[i];
    auto &dst = _dst_tensors[i];

    auto nodes = get_src_dst(src, dst);
    if (not nodes.src)
      throw std::runtime_error("The tensor named " + src + " does not exist.");

    if (not nodes.dst)
      throw std::runtime_error("The tensor named " + dst + " does not exist.");

    copy_quantparam(nodes.src, nodes.dst);

    if (auto output = dynamic_cast<luci::CircleOutput *>(nodes.dst))
    {
      auto from_node = loco::must_cast<luci::CircleNode *>(output->from());
      copy_quantparam(output, from_node);
    }

    INFO(l) << "Quantparam of " << src << " is copied to " << dst << std::endl;
  }

  INFO(l) << "CopyQuantParamPass End" << std::endl;

  return false; // one time run
}

References loco::active_nodes(), luci::copy_quantparam(), INFO, LOGGER, name(), and loco::output_nodes().

Referenced by package.infer.session::inference().
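
The control flow of run() (length check, lookup by name, error on a missing tensor, then copy) can be sketched without any luci or loco dependency. Everything below is a simplified illustration under stated assumptions: NamedTensor, find_by_name, copy_pairs and demo are hypothetical names, the graph is modeled as a plain vector, tensor names are assumed to be unique, and the extra handling of luci::CircleOutput destinations is left out. Unlike run(), which always returns false, this sketch returns the number of copied pairs to make it easier to check.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a named tensor; scale and zerop play the role of quantparam.
struct NamedTensor
{
  std::string name;
  float scale;
  int zerop;
};

// Return the first tensor with the given name, or nullptr if there is none.
NamedTensor *find_by_name(std::vector<NamedTensor> &graph, const std::string &name)
{
  for (auto &t : graph)
    if (t.name == name)
      return &t;
  return nullptr;
}

// Copy quantparam from srcs[i] to dsts[i] for every i.
// Returns the number of copied pairs (a choice of this sketch only).
std::size_t copy_pairs(std::vector<NamedTensor> &graph, const std::vector<std::string> &srcs,
                       const std::vector<std::string> &dsts)
{
  if (srcs.size() != dsts.size())
    throw std::runtime_error("The numbers of Source/Destination tensors do not match.");

  for (std::size_t i = 0; i < srcs.size(); ++i)
  {
    NamedTensor *src = find_by_name(graph, srcs[i]);
    if (src == nullptr)
      throw std::runtime_error("The tensor named " + srcs[i] + " does not exist.");

    NamedTensor *dst = find_by_name(graph, dsts[i]);
    if (dst == nullptr)
      throw std::runtime_error("The tensor named " + dsts[i] + " does not exist.");

    dst->scale = src->scale;
    dst->zerop = src->zerop;
  }
  return srcs.size();
}

// Usage sketch: copy the quantparam of "conv" to "relu".
inline void demo()
{
  std::vector<NamedTensor> graph = {{"conv", 0.5f, 3}, {"relu", 1.0f, 0}};
  copy_pairs(graph, {"conv"}, {"relu"});
  assert(graph[1].scale == 0.5f && graph[1].zerop == 3);
}
```

Validating the two name lists before touching any tensor mirrors the order of checks in run(): a length mismatch is reported first, and a missing name stops processing at that pair, so earlier pairs have already been copied by then.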


The documentation for this class was generated from the following files: