ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::odc::Embedder Class Reference

#include <Embedder.h>

Public Member Functions

void embed (luci::Module *module, const std::string &minmax_path, const EmbedderOptions &opt)
 

Detailed Description

Definition at line 32 of file Embedder.h.

Member Function Documentation

◆ embed()

void onert::odc::Embedder::embed ( luci::Module * module,
const std::string & minmax_path,
const EmbedderOptions & opt 
)

Definition at line 72 of file Embedder.cc.

74{
75 if (module == nullptr)
76 throw std::runtime_error{"Input module is nullptr"};
77
78 MinMaxReader mmr{minmax_path};
79
80 for (size_t idx = 0; idx < module->size(); ++idx)
81 {
82 auto graph = module->graph(idx);
83
84 /* read subgraph inputs */
85 const auto input_nodes = loco::input_nodes(graph);
86 const auto n_inputs = input_nodes.size();
87 for (size_t input_idx = 0; input_idx < n_inputs; ++input_idx)
88 {
89 const auto *circle_input = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
90 if (circle_input->index() != input_idx)
91 throw std::runtime_error("Input order in minmax recording does not match to circle");
92
93 auto minmax = mmr.readInput(0, idx, input_idx);
94 auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
95 auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
96 auto quantparam = std::make_unique<luci::CircleQuantParam>();
97 quantparam->min.push_back(min);
98 quantparam->max.push_back(max);
99 const auto *circle_node = loco::must_cast<const luci::CircleNode *>(input_nodes[input_idx]);
100 auto mutable_node = const_cast<luci::CircleNode *>(circle_node);
101 mutable_node->quantparam(std::move(quantparam));
102 }
103
104 /* read op outputs */
105 uint32_t n_nodes = graph->nodes()->size();
106 for (uint32_t i = 0; i < n_nodes; ++i)
107 {
108 auto node = loco::must_cast<luci::CircleNode *>(graph->nodes()->at(i));
109 if (not luci::has_node_id(node)) // Skip non-op nodes (e.g. input/const/output)
110 continue;
111 auto op_idx = luci::get_node_id(node);
112 auto minmax = mmr.readOP(0, idx, op_idx);
113 auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
114 auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
115 auto quantparam = std::make_unique<luci::CircleQuantParam>();
116 quantparam->min.push_back(min);
117 quantparam->max.push_back(max);
118 auto mutable_node = const_cast<luci::CircleNode *>(node);
119 mutable_node->quantparam(std::move(quantparam));
120 }
121
122 if (!luci::validate(graph))
123 throw std::runtime_error{"Circle after embedding minmax is invalid"};
124 }
125}
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71
CircleNodeID get_node_id(const luci::CircleNode *circle_node)
bool validate(loco::Graph *g)
bool has_node_id(const luci::CircleNode *circle_node)
float getNthPercentile(std::vector< float > &vector, float percentile)
getNthPercentile calculates the n-th percentile of the input vector (0.0 <= n <= 100.0) using linear interpolation.
CircleQuantParam * quantparam(void) const

References luci::get_node_id(), luci::has_node_id(), loco::input_nodes(), onert::odc::EmbedderOptions::max_percentile, onert::odc::EmbedderOptions::min_percentile, luci::CircleNode::quantparam(), and luci::validate().

Referenced by onert::odc::Quantizer::quantize().


The documentation for this class was generated from the following files: