ONE - On-device Neural Engine
luci::QuantizeWeightsPass Class Reference

Pass to quantize weights.

#include <QuantizeWeightsPass.h>

Data Structures

struct  Context
 

Public Member Functions

 QuantizeWeightsPass (std::unique_ptr< Context > &&ctx)
 
 QuantizeWeightsPass (loco::DataType input_model_dtype, loco::DataType output_model_dtype, QuantizationGranularity granularity)
 
virtual const char * name (void) const
 
bool run (loco::Graph *graph)
 Run the pass.
 
Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default
 

Detailed Description

Pass to quantize weights.

Definition at line 32 of file QuantizeWeightsPass.h.

Constructor & Destructor Documentation

◆ QuantizeWeightsPass() [1/2]

luci::QuantizeWeightsPass::QuantizeWeightsPass ( std::unique_ptr< Context > &&  ctx)
inline

Definition at line 43 of file QuantizeWeightsPass.h.

  : _ctx{std::move(ctx)}
{
  // DO NOTHING
}

◆ QuantizeWeightsPass() [2/2]

luci::QuantizeWeightsPass::QuantizeWeightsPass ( loco::DataType  input_model_dtype,
loco::DataType  output_model_dtype,
QuantizationGranularity  granularity 
)
inline

Definition at line 49 of file QuantizeWeightsPass.h.

{
  _ctx = std::make_unique<Context>();
  {
    _ctx->input_model_dtype = input_model_dtype;
    _ctx->output_model_dtype = output_model_dtype;
    _ctx->granularity = granularity;
  }
}
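
The two constructors differ only in who builds the Context. The following is a minimal sketch of that relationship, assuming stand-in types: the `ctor_sketch` namespace, the enumerator lists, the default member values, and the `context()` accessor are all hypothetical and exist only so the example compiles without the luci and loco headers.

```cpp
#include <memory>
#include <utility>

// Stand-ins for loco::DataType and luci::QuantizationGranularity.
// The enumerator lists are assumptions made for this sketch.
namespace ctor_sketch
{
enum class DataType { FLOAT32, S8, S16 };
enum class QuantizationGranularity { LayerWise, ChannelWise };

class QuantizeWeightsPass
{
public:
  // Field names follow the constructor body documented above;
  // the default values are assumptions.
  struct Context
  {
    DataType input_model_dtype = DataType::FLOAT32;
    DataType output_model_dtype = DataType::FLOAT32;
    QuantizationGranularity granularity = QuantizationGranularity::LayerWise;
  };

  // [1/2] takes ownership of a Context built by the caller
  QuantizeWeightsPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} {}

  // [2/2] builds the Context from the individual settings
  QuantizeWeightsPass(DataType input_model_dtype, DataType output_model_dtype,
                      QuantizationGranularity granularity)
  {
    _ctx = std::make_unique<Context>();
    _ctx->input_model_dtype = input_model_dtype;
    _ctx->output_model_dtype = output_model_dtype;
    _ctx->granularity = granularity;
  }

  // Hypothetical accessor, added only so the sketch can compare results.
  const Context &context() const { return *_ctx; }

private:
  std::unique_ptr<Context> _ctx;
};

// Returns true when both construction routes end up with the same settings.
inline bool same_settings()
{
  auto ctx = std::make_unique<QuantizeWeightsPass::Context>();
  ctx->input_model_dtype = DataType::FLOAT32;
  ctx->output_model_dtype = DataType::S8;
  ctx->granularity = QuantizationGranularity::ChannelWise;

  QuantizeWeightsPass from_context(std::move(ctx));
  QuantizeWeightsPass from_values(DataType::FLOAT32, DataType::S8,
                                  QuantizationGranularity::ChannelWise);

  const auto &a = from_context.context();
  const auto &b = from_values.context();
  return a.input_model_dtype == b.input_model_dtype &&
         a.output_model_dtype == b.output_model_dtype && a.granularity == b.granularity;
}
} // namespace ctor_sketch
```

The Context overload is likely the more convenient one when settings are assembled in one place and handed over; the three-argument overload suits call sites that already hold the individual values.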

Member Function Documentation

◆ name()

virtual const char * luci::QuantizeWeightsPass::name ( void  ) const
inline virtual

Reimplemented from logo::Pass.

Definition at line 59 of file QuantizeWeightsPass.h.

{ return "luci::QuantizeWeightsPass"; }

◆ run()

bool luci::QuantizeWeightsPass::run ( loco::Graph *  graph)
virtual

Run the pass.

Returns
false if nothing was changed. This implementation always returns false, because the pass is meant to run only once.

Implements logo::Pass.

Definition at line 26 of file QuantizeWeightsPass.cpp.

{
  LOGGER(l);
  INFO(l) << "QuantizeWeightsPass Start" << std::endl;

  if (_ctx->input_model_dtype != loco::DataType::FLOAT32)
    throw std::runtime_error("Weights-only quantization supports float32 input only");

  // Quantize weights
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    QuantizeWeightsOnly qw(_ctx->input_model_dtype, _ctx->output_model_dtype, _ctx->granularity);
    circle_node->accept(&qw);
  }

  INFO(l) << "QuantizeWeightsPass End" << std::endl;
  return false; // one time run
}

References loco::active_nodes(), INFO, LOGGER, and loco::output_nodes().

Referenced by package.infer.session::inference().
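
The control flow of run() can be sketched without the luci and loco headers. The example below is an illustration under stated assumptions, not the library's code: `run_sketch`, `Node`, `run_once`, and `run_until_stable` are hypothetical names, and setting a flag on a node stands in for the real visitor's work. It shows the three observable behaviours of the listing: a non-float32 input is rejected with std::runtime_error, every node is visited, and the constant false return makes a repeat-while-changed driver stop after one run.

```cpp
#include <stdexcept>
#include <vector>

namespace run_sketch
{
// Stand-in for loco::DataType; the enumerator list is an assumption.
enum class DataType { FLOAT32, S8 };

// Hypothetical stand-in for a graph node that carries weights.
struct Node
{
  bool quantized = false;
};

// Mirrors the shape of run(): validate the input dtype, visit every node,
// then report "nothing more to change".
inline bool run_once(DataType input_model_dtype, std::vector<Node> &nodes)
{
  if (input_model_dtype != DataType::FLOAT32)
    throw std::runtime_error("Weights-only quantization supports float32 input only");

  for (auto &node : nodes)
    node.quantized = true; // placeholder for the visitor's per-node work

  return false; // one time run
}

// Hypothetical driver: repeats a pass while it reports a change and
// returns how many times the pass was executed.
template <typename Pass> int run_until_stable(Pass pass)
{
  int runs = 0;
  bool changed = true;
  while (changed)
  {
    changed = pass();
    ++runs;
  }
  return runs;
}
} // namespace run_sketch
```

With a driver of this kind, a pass that returned true after modifying the graph would be scheduled again; returning false unconditionally avoids a second visit to the same weights.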


The documentation for this class was generated from the following files: