ONE - On-device Neural Engine
luci::RequantizePass Class Reference

Pass to re-quantize a graph (e.g. int8 -> uint8).

#include <RequantizePass.h>

Public Member Functions

 RequantizePass (loco::DataType input_dtype, loco::DataType output_dtype)
 
virtual const char * name (void) const
 
bool run (loco::Graph *graph)
 Run the pass.
 
- Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

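Pass to re-quantize a graph (e.g. int8 -> uint8). As implemented in run(), only the conversion from int8 (loco::DataType::S8) to uint8 (loco::DataType::U8) is handled; for any other combination of input and output dtype the pass returns without modifying the graph.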

Definition at line 32 of file RequantizePass.h.
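
As an illustration of what an int8 -> uint8 re-quantization amounts to arithmetically, here is a minimal sketch. It assumes the usual affine scheme, real = scale * (q - zero_point), and shifts both the stored values and the zero point by 128 so that the represented real values do not change. QuantizedTensor, dequantize and s8_to_u8 are hypothetical names used only for this example; they are not luci types, and the actual RequantizeS8ToU8 visitor is not shown on this page.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a quantized tensor; this is not a luci type.
struct QuantizedTensor
{
  std::vector<int32_t> values; // quantized values, stored wide for simplicity
  float scale;
  int32_t zero_point;
};

// Real value represented by element i: scale * (q - zero_point).
inline float dequantize(const QuantizedTensor &t, std::size_t i)
{
  return t.scale * static_cast<float>(t.values[i] - t.zero_point);
}

// Shift an int8-range tensor [-128, 127] into the uint8 range [0, 255].
// Values and zero point move by the same offset, so the scale is unchanged.
inline QuantizedTensor s8_to_u8(const QuantizedTensor &in)
{
  QuantizedTensor out = in;
  for (auto &v : out.values)
    v += 128;
  out.zero_point += 128;
  return out;
}
```

Because the offset cancels inside (q - zero_point), dequantizing the shifted tensor gives the same real values as dequantizing the original one.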

Constructor & Destructor Documentation

◆ RequantizePass()

luci::RequantizePass::RequantizePass ( loco::DataType  input_dtype,
loco::DataType  output_dtype 
)
inline

Definition at line 35 of file RequantizePass.h.

  : _input_dtype{input_dtype}, _output_dtype{output_dtype}
{
  // DO NOTHING
}

Member Function Documentation

◆ name()

virtual const char * luci::RequantizePass::name ( void  ) const
inline virtual

Reimplemented from logo::Pass.

Definition at line 40 of file RequantizePass.h.

{ return "luci::RequantizePass"; }

◆ run()

bool luci::RequantizePass::run ( loco::Graph *  graph )
virtual

Run the pass.

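Returns
false if nothing was changed. Note that the implementation always returns false, even after re-quantizing nodes, so that the pass is run only once.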

Implements logo::Pass.
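
To show why a pass that always returns false is effectively a one-time pass, here is a minimal sketch of a driver that repeats its passes until a full sweep reports no change. Graph, Pass, OneTimePass and run_until_saturated are hypothetical stand-ins written for this example; they are not the loco or logo types, and the real pass runner may behave differently.

```cpp
#include <vector>

// Hypothetical stand-ins; these are not the loco/logo types.
struct Graph
{
  int rewrites = 0; // counts how many times a pass touched the graph
};

struct Pass
{
  virtual ~Pass() = default;
  virtual const char *name() const = 0;
  // Returns true if the graph was changed.
  virtual bool run(Graph *g) = 0;
};

// Applies its rewrite unconditionally and always reports "no change".
struct OneTimePass final : Pass
{
  const char *name() const override { return "OneTimePass"; }
  bool run(Graph *g) override
  {
    ++g->rewrites;
    return false; // one time run
  }
};

// Re-runs all passes until a full sweep reports no change.
// Returns the number of sweeps performed.
inline int run_until_saturated(Graph *g, const std::vector<Pass *> &passes)
{
  int sweeps = 0;
  bool changed = true;
  while (changed)
  {
    changed = false;
    for (auto *pass : passes)
      changed = pass->run(g) || changed;
    ++sweeps;
  }
  return sweeps;
}
```

With such a driver, a pass that reported true after an unconditional rewrite would be re-run on every sweep; returning false stops the loop after the first one.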

Definition at line 126 of file RequantizePass.cpp.

{
  LOGGER(l);
  INFO(l) << "RequantizePass Start" << std::endl;

  // Input: int8 model
  // Output: uint8 model
  if (_input_dtype == loco::DataType::S8 and _output_dtype == loco::DataType::U8)
  {
    for (auto node : loco::active_nodes(loco::output_nodes(g)))
    {
      RequantizeS8ToU8 rq;
      auto circle_node = loco::must_cast<luci::CircleNode *>(node);
      circle_node->accept(&rq);
    }
  }
  else
  {
    // Ignore other cases
    return false;
  }

  // Fix wrong quantized_dimension
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);

    auto qparam = circle_node->quantparam();
    if (not qparam)
      continue;

    if (circle_node->rank() != 1)
      continue;

    if (qparam->quantized_dimension == 0)
      continue;

    // For rank 1 node, quantized_dimension should be 0
    qparam->quantized_dimension = 0;
    WARN(l) << "Wrong quantized_dimension is fixed (" << circle_node->name() << ")" << std::endl;
  }

  // Update output dtype
  auto graph_outputs = g->outputs();
  for (auto node : loco::output_nodes(g))
  {
    auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
    auto from_node = loco::must_cast<luci::CircleNode *>(circle_node->from());
    if (from_node->dtype() == _output_dtype)
    {
      circle_node->dtype(_output_dtype);
      auto graph_output = graph_outputs->at(circle_node->index());
      graph_output->dtype(_output_dtype);
    }
  }

  INFO(l) << "RequantizePass End" << std::endl;
  return false; // one time run
}

References loco::active_nodes(), INFO, LOGGER, loco::output_nodes(), and WARN.

Referenced by package.infer.session::inference().
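
The "Fix wrong quantized_dimension" step in run() can be sketched in isolation as follows. This is a simplified illustration with hypothetical stand-in types (QuantParam, Node) and a hypothetical helper name (fix_rank1_quantized_dimension); it works on a plain vector rather than a loco graph and only mirrors the three skip conditions and the reset to 0.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-ins for this example; these are not luci types.
struct QuantParam
{
  int32_t quantized_dimension = 0;
};

struct Node
{
  uint32_t rank = 0;
  std::unique_ptr<QuantParam> qparam; // null when the node is not quantized
};

// For every quantized rank-1 node, force quantized_dimension to 0.
// Returns how many nodes were corrected.
inline int fix_rank1_quantized_dimension(std::vector<Node> &nodes)
{
  int fixed = 0;
  for (auto &node : nodes)
  {
    auto *qparam = node.qparam.get();
    if (qparam == nullptr)
      continue; // no quantization parameters
    if (node.rank != 1)
      continue; // only rank-1 nodes are checked
    if (qparam->quantized_dimension == 0)
      continue; // already correct
    // A rank-1 tensor has a single axis, so the only valid dimension is 0.
    qparam->quantized_dimension = 0;
    ++fixed;
  }
  return fixed;
}
```

The early `continue` checks follow the order used in run(): missing quantization parameters first, then rank, then the value itself, so only nodes that actually need a fix are modified.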


The documentation for this class was generated from the following files: