ONE - On-device Neural Engine
luci::QuantizeDequantizeWeightsWithGPTQPass Class Reference

Pass to quantize weights with the GPTQ algorithm.

#include <QuantizeDequantizeWeightsWithGPTQPass.h>

Data Structures

struct  Context
 

Public Member Functions

 QuantizeDequantizeWeightsWithGPTQPass (std::unique_ptr< Context > &&ctx, HessianMap *hessian_map)
 
virtual const char * name (void) const
 
bool run (loco::Graph *graph)
 Run the pass.
 
Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

Pass to quantize weights with the GPTQ algorithm.

Definition at line 36 of file QuantizeDequantizeWeightsWithGPTQPass.h.

Constructor & Destructor Documentation

◆ QuantizeDequantizeWeightsWithGPTQPass()

luci::QuantizeDequantizeWeightsWithGPTQPass::QuantizeDequantizeWeightsWithGPTQPass ( std::unique_ptr< Context > &&  ctx,
HessianMap *  hessian_map
)
inline

Definition at line 48 of file QuantizeDequantizeWeightsWithGPTQPass.h.

  : _ctx{std::move(ctx)}, _hessian_map{hessian_map}
{
  // DO NOTHING
}
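
The initializer list above moves the Context into the pass and stores the HessianMap pointer as given. A minimal sketch of what that implies for the caller, assuming the member is a plain non-owning pointer: Context, HessianMap, and GptqPassSketch below are hypothetical stand-ins, not the real luci types.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-ins for the real luci types; fields are assumptions.
struct Context
{
  std::string output_model_dtype;
};
using HessianMap = std::map<std::string, double>;

// Mirrors the constructor shape shown above, nothing more.
class GptqPassSketch
{
public:
  GptqPassSketch(std::unique_ptr<Context> &&ctx, HessianMap *hessian_map)
    : _ctx{std::move(ctx)}, _hessian_map{hessian_map}
  {
    // DO NOTHING
  }

  const Context *context() const { return _ctx.get(); }
  const HessianMap *hessian_map() const { return _hessian_map; }

private:
  std::unique_ptr<Context> _ctx; // owned by the pass after construction
  HessianMap *_hessian_map;      // borrowed; the caller must keep it alive
};

// Builds a pass and reports whether the caller's unique_ptr was emptied.
inline bool caller_ctx_is_empty_after_construction(HessianMap &hessians)
{
  auto ctx = std::make_unique<Context>();
  ctx->output_model_dtype = "U8";

  GptqPassSketch pass(std::move(ctx), &hessians);

  // The pass sees the caller's map, not a copy.
  assert(pass.hessian_map() == &hessians);
  assert(pass.context() != nullptr);

  // The Context now belongs to the pass.
  return ctx == nullptr;
}
```

Under these assumptions the caller gives up the Context but remains responsible for keeping the HessianMap alive for as long as the pass may run.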

Member Function Documentation

◆ name()

virtual const char * luci::QuantizeDequantizeWeightsWithGPTQPass::name ( void  ) const
inline virtual

Reimplemented from logo::Pass.

Definition at line 53 of file QuantizeDequantizeWeightsWithGPTQPass.h.

{ return "luci::QuantizeDequantizeWeightsWithGPTQPass"; }

◆ run()

bool luci::QuantizeDequantizeWeightsWithGPTQPass::run ( loco::Graph *  graph )
virtual

Run the pass.

Returns
false if nothing was changed

Implements logo::Pass.

Definition at line 149 of file QuantizeDequantizeWeightsWithGPTQPass.cpp.

{
  LOGGER(l);
  INFO(l) << "QuantizeDequantizeWeightsWithGPTQ Start" << std::endl;

  if (_ctx->input_model_dtype != loco::DataType::FLOAT32)
    throw std::runtime_error("GPTQPass: Weights-only quantization supports float32 input only");

  if (_ctx->output_model_dtype != loco::DataType::U8 &&
      _ctx->output_model_dtype != loco::DataType::U4)
  {
    throw std::runtime_error("GPTQPass: GPTQ quantization supports uint4/uint8");
  }

  auto info_by_name = layer_info_map(g, _ctx->layers_info);

  auto quantize_dtype = [&](const luci::CircleNode *node) {
    auto iter = info_by_name.find(node->name());

    // Return designated quantization dtype
    if (iter != info_by_name.end())
      return iter->second.dtype;

    // Return default quantization dtype
    return _ctx->output_model_dtype;
  };

  auto quantize_granularity = [&](const luci::CircleNode *node) {
    auto iter = info_by_name.find(node->name());

    // Return designated quantization granularity
    if (iter != info_by_name.end())
      return iter->second.granularity;

    // Return default quantization granularity
    return _ctx->granularity;
  };

  // Quantize weights
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    QuantizeDequantizeWeightsWithGPTQ qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
                                         quantize_granularity(circle_node), _hessian_map);
    circle_node->accept(&qw);
  }

  INFO(l) << "QuantizeDequantizeWeightsWithGPTQ End" << std::endl;
  return false; // one time run
}
References loco::active_nodes(), INFO, luci::layer_info_map(), LOGGER, luci::CircleNode::name(), and loco::output_nodes().
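
The body of run() first validates the model dtypes, then resolves each node's quantization dtype and granularity from a per-layer table, falling back to the Context defaults when the node has no entry. The same pattern can be sketched in isolation as follows. All types and helper names here (DataType, Granularity, LayerInfo, Context, index_by_name, validate, resolve) are simplified, hypothetical stand-ins rather than the luci or loco definitions, and the real layer_info_map() takes the graph as well.

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the loco/luci types.
enum class DataType { FLOAT32, U8, U4, S16 };
enum class Granularity { LayerWise, ChannelWise };

struct LayerInfo
{
  std::string name;
  DataType dtype;
  Granularity granularity;
};

struct Context
{
  DataType input_model_dtype;
  DataType output_model_dtype;
  Granularity granularity;
  std::vector<LayerInfo> layers_info;
};

struct Resolved
{
  DataType dtype;
  Granularity granularity;
};

// Index the per-layer overrides by layer name.
inline std::map<std::string, LayerInfo> index_by_name(const std::vector<LayerInfo> &infos)
{
  std::map<std::string, LayerInfo> by_name;
  for (const auto &info : infos)
    by_name[info.name] = info;
  return by_name;
}

// Same checks as the start of run(): float32 in, uint8 or uint4 out.
inline void validate(const Context &ctx)
{
  if (ctx.input_model_dtype != DataType::FLOAT32)
    throw std::runtime_error("float32 input only");

  if (ctx.output_model_dtype != DataType::U8 && ctx.output_model_dtype != DataType::U4)
    throw std::runtime_error("uint4/uint8 output only");
}

// Use the designated per-layer settings when present, else the defaults.
inline Resolved resolve(const Context &ctx, const std::map<std::string, LayerInfo> &by_name,
                        const std::string &node_name)
{
  auto iter = by_name.find(node_name);
  if (iter != by_name.end())
    return Resolved{iter->second.dtype, iter->second.granularity};

  return Resolved{ctx.output_model_dtype, ctx.granularity};
}
```

With a default of U8 and a single override that sets a layer named "conv1" to U4, resolve() yields U4 for "conv1" and U8 for every other node name. Note that in this sketch, as in the listing above, only the default output dtype is validated; per-layer dtypes are used as given.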

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: