ONE - On-device Neural Engine
Loading...
Searching...
No Matches
minmax_embedder::Embedder Class Reference

#include <Embedder.h>

Public Member Functions

void embed (const std::string &output_path, const std::string &input_path, const std::string &minmax_path, const EmbedderOptions &)
 

Detailed Description

Definition at line 31 of file Embedder.h.

Member Function Documentation

◆ embed()

void minmax_embedder::Embedder::embed ( const std::string &  output_path,
const std::string &  input_path,
const std::string &  minmax_path,
const EmbedderOptions &  opt 
)

Definition at line 77 of file Embedder.cpp.

79{
80 // Load model from the file
81 luci::ImporterEx importerex;
82 auto module = importerex.importVerifyModule(input_path);
83 if (module.get() == nullptr)
84 throw std::runtime_error{"Input circle is invalid"};
85
86 h5::Reader mmr{minmax_path};
87
88 for (size_t idx = 0; idx < module->size(); ++idx)
89 {
90 auto graph = module->graph(idx);
91
92 /* read subgraph inputs */
93 const auto input_nodes = loco::input_nodes(graph);
94 const auto n_inputs = input_nodes.size();
95 for (size_t input_idx = 0; input_idx < n_inputs; ++input_idx)
96 {
97 const auto *circle_input = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
98 if (circle_input->index() != input_idx)
99 throw std::runtime_error("Input order in minmax recording does not match to circle");
100
101 auto minmax = mmr.read_input(0, idx, input_idx);
102 auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
103 auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
104 auto quantparam = std::make_unique<luci::CircleQuantParam>();
105 quantparam->min.push_back(min);
106 quantparam->max.push_back(max);
107 const auto *circle_node = loco::must_cast<const luci::CircleNode *>(input_nodes[input_idx]);
108 auto mutable_node = const_cast<luci::CircleNode *>(circle_node);
109 mutable_node->quantparam(std::move(quantparam));
110 }
111
112 /* read op outputs */
113 uint32_t n_nodes = graph->nodes()->size();
114 for (uint32_t i = 0; i < n_nodes; ++i)
115 {
116 auto node = loco::must_cast<luci::CircleNode *>(graph->nodes()->at(i));
117 if (not luci::has_node_id(node)) // Skip non-op nodes (e.g. input/const/output)
118 continue;
119 auto op_idx = luci::get_node_id(node);
120 auto minmax = mmr.read(0, idx, op_idx);
121 auto min = getNthPercentile(minmax.min_vector, opt.min_percentile);
122 auto max = getNthPercentile(minmax.max_vector, opt.max_percentile);
123 auto quantparam = std::make_unique<luci::CircleQuantParam>();
124 quantparam->min.push_back(min);
125 quantparam->max.push_back(max);
126 auto mutable_node = const_cast<luci::CircleNode *>(node);
127 mutable_node->quantparam(std::move(quantparam));
128 }
129
130 if (!luci::validate(graph))
131 throw std::runtime_error{"Circle after embedding minmax is invalid"};
132 }
133
134 // Export to output Circle file
135 luci::CircleExporter exporter;
136
137 luci::CircleFileExpContract contract(module.get(), output_path);
138
139 if (!exporter.invoke(&contract))
140 throw std::runtime_error{"Failed to export circle"};
141}
bool invoke(Contract *) const
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71
CircleNodeID get_node_id(const luci::CircleNode *circle_node)
bool validate(loco::Graph *)
bool has_node_id(const luci::CircleNode *circle_node)
float getNthPercentile(std::vector< float > &vector, float percentile)
getNthPercentile calculates the n-th percentile of the input vector (0.0 <= n <= 100.0), using linear interpolation.
CircleQuantParam * quantparam(void) const

References luci::get_node_id(), luci::has_node_id(), loco::input_nodes(), luci::CircleExporter::invoke(), minmax_embedder::EmbedderOptions::max_percentile, minmax_embedder::EmbedderOptions::min_percentile, luci::CircleNode::quantparam(), and luci::validate().

Referenced by entry().


The documentation for this class was generated from the following files: