ONE - On-device Neural Engine
Loading...
Searching...
No Matches
CircleModel.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "CircleModel.h"

#include <mio/circle/schema_generated.h>

#include <luci/Importer.h>
#include <luci/CircleExporter.h>
#include <luci/Import/GraphBuilderRegistry.h>

#include <cassert>
#include <fstream>
#include <stdexcept>
26
27using namespace circle_resizer;
28
29namespace
30{
31
/**
 * @brief Read the whole file at @p model_path into a byte buffer.
 *
 * The file is opened in binary mode with the position at its end, so the size
 * can be taken from the stream position before the content is read from the start.
 *
 * @param model_path Path of the file to read.
 * @return Content of the file as bytes.
 * @throws std::runtime_error If the file cannot be opened, its size cannot be
 *         determined, or the content cannot be read completely.
 */
std::vector<uint8_t> read_model(const std::string &model_path)
{
  std::ifstream file_stream(model_path, std::ios::in | std::ios::binary | std::ios::ate);
  if (!file_stream.is_open())
  {
    throw std::runtime_error("Failed to open file: " + model_path);
  }

  // tellg() reports failure as -1; used unchecked as a vector size it would
  // wrap around to an enormous unsigned length.
  const std::streamoff size = file_stream.tellg();
  if (size < 0)
  {
    throw std::runtime_error("Failed to read file: " + model_path);
  }
  file_stream.seekg(0, std::ios::beg);

  std::vector<uint8_t> buffer(static_cast<std::size_t>(size));
  if (!file_stream.read(reinterpret_cast<char *>(buffer.data()),
                        static_cast<std::streamsize>(size)))
  {
    throw std::runtime_error("Failed to read file: " + model_path);
  }

  return buffer;
}
51
52std::unique_ptr<luci::Module> load_module(const std::vector<uint8_t> &model_buffer)
53{
54 flatbuffers::Verifier verifier{model_buffer.data(), model_buffer.size()};
55 if (!circle::VerifyModelBuffer(verifier))
56 {
57 throw std::runtime_error("Verification of the model failed");
58 }
59
61 luci::Importer importer(source_ptr);
62 return importer.importModule(model_buffer.data(), model_buffer.size());
63}
64
65class BufferModelContract : public luci::CircleExporter::Contract
66{
67public:
68 BufferModelContract(luci::Module *module)
69 : _module(module), _buffer{std::make_unique<std::vector<uint8_t>>()}
70 {
71 assert(_module); // FIX_CALLER_UNLESS
72 }
73
74 luci::Module *module() const override { return _module; }
75
76 bool store(const char *ptr, const size_t size) const override
77 {
78 _buffer->resize(size);
79 std::copy(ptr, ptr + size, _buffer->begin());
80 return true;
81 }
82
83 std::vector<uint8_t> get_buffer() { return *_buffer; }
84
85private:
86 luci::Module *_module;
87 std::unique_ptr<std::vector<uint8_t>> _buffer; // note that the store method has to be const
88};
89
90template <typename NodeType>
91std::vector<Shape> extract_shapes(const std::vector<loco::Node *> &nodes)
92{
93 std::vector<Shape> shapes;
94 for (const auto &loco_node : nodes)
95 {
96 std::vector<Dim> dims;
97 const auto circle_node = loco::must_cast<const NodeType *>(loco_node);
98 for (uint32_t dim_idx = 0; dim_idx < circle_node->rank(); dim_idx++)
99 {
100 if (circle_node->dim(dim_idx).known())
101 {
102 const int32_t dim_val = circle_node->dim(dim_idx).value();
103 dims.push_back(Dim{dim_val});
104 }
105 else
106 {
107 dims.push_back(Dim{-1});
108 }
109 }
110 shapes.push_back(Shape{dims});
111 }
112 return shapes;
113}
114
115} // namespace
116
117CircleModel::CircleModel(const std::vector<uint8_t> &buffer) : _module{load_module(buffer)} {}
118
119CircleModel::CircleModel(const std::string &model_path) : CircleModel(read_model(model_path)) {}
120
121luci::Module *CircleModel::module() { return _module.get(); }
122
123void CircleModel::save(std::ostream &stream)
124{
125 BufferModelContract contract(module());
126 luci::CircleExporter exporter;
127 if (!exporter.invoke(&contract))
128 {
129 throw std::runtime_error("Exporting buffer from the model failed");
130 }
131
132 auto model_buffer = contract.get_buffer();
133 stream.write(reinterpret_cast<const char *>(model_buffer.data()), model_buffer.size());
134 if (!stream.good())
135 {
136 throw std::runtime_error("Failed to write to output stream");
137 }
138}
139
140void CircleModel::save(const std::string &output_path)
141{
142 std::ofstream out_stream(output_path, std::ios::out | std::ios::binary);
143 save(out_stream);
144}
145
146std::vector<Shape> CircleModel::input_shapes() const
147{
148 return extract_shapes<luci::CircleInput>(loco::input_nodes(_module->graph()));
149}
150
151std::vector<Shape> CircleModel::output_shapes() const
152{
153 return extract_shapes<luci::CircleOutput>(loco::output_nodes(_module->graph()));
154}
155
156CircleModel::~CircleModel() = default;
luci::Module * module()
Get the loaded model in luci::Module representation.
std::vector< Shape > input_shapes() const
Get input shapes of the loaded model.
~CircleModel()
Dtor of CircleModel. Note that an explicit declaration is needed to satisfy forward declaration + unique_ptr.
std::vector< Shape > output_shapes() const
Get output shapes of the loaded model.
CircleModel(const std::vector< uint8_t > &buffer)
Initialize the model with buffer representation.
void save(std::ostream &stream)
Save the model to the output stream.
bool invoke(Contract *) const
static GraphBuilderRegistry & get()
Collection of 'loco::Graph's.
Definition Module.h:33
std::vector< int > dims(const std::string &src)
Definition Utils.h:35
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101
int32_t size[5]
Definition Slice.cpp:35
Definition Shape.h:28
virtual bool store(const char *ptr, const size_t size) const =0
virtual luci::Module * module(void) const =0