ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::odc::Embedder Class Reference

#include <Embedder.h>

Public Member Functions

void embed (luci::Module *module, const std::string &minmax_path, const EmbedderOptions &opt)
 

Detailed Description

Definition at line 34 of file Embedder.h.

Member Function Documentation

◆ embed()

void onert::odc::Embedder::embed ( luci::Module *  module,
const std::string &  minmax_path,
const EmbedderOptions &  opt 
)

Definition at line 74 of file Embedder.cc.

76{
77 if (module == nullptr)
78 throw std::runtime_error{"Input module is nullptr"};
79
80 MinMaxReader mmr{minmax_path};
81
82 for (size_t idx = 0; idx < module->size(); ++idx)
83 {
84 auto graph = module->graph(idx);
85
86 /* read subgraph inputs */
87 const auto input_nodes = loco::input_nodes(graph);
88 const auto n_inputs = input_nodes.size();
89 for (size_t input_idx = 0; input_idx < n_inputs; ++input_idx)
90 {
91 const auto *circle_input = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
92 if (circle_input->index() != input_idx)
93 throw std::runtime_error("Input order in minmax recording does not match to circle");
94
95 auto minmax = mmr.readInput(0, idx, input_idx);
96 auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
97 auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
98 auto quantparam = std::make_unique<luci::CircleQuantParam>();
99 quantparam->min.push_back(min);
100 quantparam->max.push_back(max);
101 const auto *circle_node = loco::must_cast<const luci::CircleNode *>(input_nodes[input_idx]);
102 auto mutable_node = const_cast<luci::CircleNode *>(circle_node);
103 mutable_node->quantparam(std::move(quantparam));
104 }
105
106 /* read op outputs */
107 uint32_t n_nodes = graph->nodes()->size();
108 for (uint32_t i = 0; i < n_nodes; ++i)
109 {
110 auto node = loco::must_cast<luci::CircleNode *>(graph->nodes()->at(i));
111 if (not luci::has_node_id(node)) // Skip non-op nodes (e.g. input/const/output)
112 continue;
113 auto op_idx = luci::get_node_id(node);
114 auto minmax = mmr.readOP(0, idx, op_idx);
115 auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
116 auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
117 auto quantparam = std::make_unique<luci::CircleQuantParam>();
118 quantparam->min.push_back(min);
119 quantparam->max.push_back(max);
120 auto mutable_node = const_cast<luci::CircleNode *>(node);
121 mutable_node->quantparam(std::move(quantparam));
122 }
123
124 if (!luci::validate(graph))
125 throw std::runtime_error{"Circle after embedding minmax is invalid"};
126 }
127}
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71
CircleNodeID get_node_id(const luci::CircleNode *circle_node)
bool validate(luci::PartitionTable &partition)
bool has_node_id(const luci::CircleNode *circle_node)
float getNthPercentile(std::vector< float > &vector, float percentile)
getNthPercentile calculates the n-th percentile of the input vector (0.0 <= n <= 100.0) using linear interpolation.
CircleQuantParam * quantparam(void) const

References luci::get_node_id(), luci::has_node_id(), loco::input_nodes(), onert::odc::EmbedderOptions::max_percentile, onert::odc::EmbedderOptions::min_percentile, luci::CircleNode::quantparam(), and luci::validate().

Referenced by onert::odc::Quantizer::quantize().


The documentation for this class was generated from the following files: