17#include "CircleModel.h"
19#include <mio/circle/schema_generated.h>
/**
 * @brief Read the whole content of a binary file into memory.
 *
 * @param model_path path of the file to read
 * @return the raw bytes of the file (an empty vector for an empty file)
 * @throws std::runtime_error if the file cannot be opened, its size cannot be
 *         determined, or reading its content fails
 */
std::vector<uint8_t> read_model(const std::string &model_path)
{
  // Open positioned at the end so that tellg() reports the file size.
  std::ifstream file_stream(model_path, std::ios::in | std::ios::binary | std::ifstream::ate);
  if (!file_stream.is_open())
  {
    throw std::runtime_error("Failed to open file: " + model_path);
  }

  const std::streamsize size = file_stream.tellg();
  // tellg() signals failure with -1; without this check the value would be
  // converted to a huge unsigned element count in the vector constructor below.
  if (size < 0)
  {
    throw std::runtime_error("Failed to get size of file: " + model_path);
  }
  file_stream.seekg(0, std::ios::beg);

  std::vector<uint8_t> buffer(size);
  if (!file_stream.read(reinterpret_cast<char *>(buffer.data()), size))
  {
    throw std::runtime_error("Failed to read file: " + model_path);
  }

  return buffer;
}
// Verify a serialized circle model buffer and import it as a luci::Module.
// @param model_buffer raw bytes of a circle (flatbuffers) model
// @return the module produced by importer.importModule() for this buffer
// @throws std::runtime_error if flatbuffers verification of the buffer fails
52std::unique_ptr<luci::Module> load_module(
const std::vector<uint8_t> &model_buffer)
// Verify before importing so a malformed buffer is rejected before any import work.
54 flatbuffers::Verifier verifier{model_buffer.data(), model_buffer.size()};
55 if (!circle::VerifyModelBuffer(verifier))
57 throw std::runtime_error(
"Verification of the model failed");
// Hand the verified bytes to the importer to build the module.
62 return importer.importModule(model_buffer.data(), model_buffer.size());
69 : _module(
module), _buffer{
std::make_unique<
std::vector<uint8_t>>()}
// Contract callback (presumably called by the exporter during invoke()) that
// receives the serialized model bytes. Copies `size` bytes starting at `ptr`
// into the owned buffer, replacing its previous content.
// Declared const to match the base-class signature; writing is still possible
// because the bytes live behind the _buffer pointer, not in the object itself.
76 bool store(
const char *ptr,
const size_t size)
const override
// Size the destination first so std::copy writes into existing elements.
78 _buffer->resize(
size);
79 std::copy(ptr, ptr +
size, _buffer->begin());
83 std::vector<uint8_t> get_buffer() {
return *_buffer; }
// Destination for the bytes passed to store(); read back by get_buffer().
// Kept behind a unique_ptr, which lets the const store() override modify the
// vector's contents (presumably the reason it is not a plain member).
87 std::unique_ptr<std::vector<uint8_t>> _buffer;
90template <
typename NodeType>
91std::vector<Shape> extract_shapes(
const std::vector<loco::Node *> &nodes)
93 std::vector<Shape> shapes;
94 for (
const auto &loco_node : nodes)
96 std::vector<Dim>
dims;
97 const auto circle_node = loco::must_cast<const NodeType *>(loco_node);
98 for (uint32_t dim_idx = 0; dim_idx < circle_node->rank(); dim_idx++)
100 if (circle_node->dim(dim_idx).known())
102 const int32_t dim_val = circle_node->dim(dim_idx).value();
125 BufferModelContract contract(
module());
127 if (!exporter.
invoke(&contract))
129 throw std::runtime_error(
"Exporting buffer from the model failed");
132 auto model_buffer = contract.get_buffer();
133 stream.write(
reinterpret_cast<const char *
>(model_buffer.data()), model_buffer.size());
136 throw std::runtime_error(
"Failed to write to output stream");
142 std::ofstream out_stream(output_path, std::ios::out | std::ios::binary);
luci::Module * module()
Get the loaded model in luci::Module representation.
std::vector< Shape > input_shapes() const
Get input shapes of the loaded model.
~CircleModel()
Destructor of CircleModel. It is declared explicitly because the class combines a forward-declared type with std::unique_ptr.
std::vector< Shape > output_shapes() const
Get output shapes of the loaded model.
CircleModel(const std::vector< uint8_t > &buffer)
Initialize the model with buffer representation.
void save(std::ostream &stream)
Save the model to the output stream.
bool invoke(Contract *) const
static GraphBuilderRegistry & get()
Collection of 'loco::Graph's.
std::vector< int > dims(const std::string &src)
std::vector< Node * > input_nodes(const Graph *)
std::vector< Node * > output_nodes(Graph *)
virtual bool store(const char *ptr, const size_t size) const =0
virtual luci::Module * module(void) const =0