ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::exporter::CircleExporter Class Reference

#include <CircleExporter.h>

Public Member Functions

 CircleExporter (const std::string &source, const std::string &path)
 
 ~CircleExporter ()
 
void updateWeight (const std::unique_ptr< onert::exec::Execution > &exec)
 
void updateMetadata (const std::unique_ptr< onert::ir::train::TrainingInfo > &training_info)
 

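Detailed Description

CircleExporter loads a circle model from a source file and unpacks it into an in-memory, mutable object tree. Trained weights (updateWeight()) and training metadata (updateMetadata()) can then be written back into that tree. The output path given to the constructor is stored for later use. The destructor calls finish(), which presumably serializes the updated model to that path.

Definition at line 36 of file CircleExporter.h.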

Constructor & Destructor Documentation

◆ CircleExporter()

onert::exporter::CircleExporter::CircleExporter ( const std::string &  source,
const std::string &  path 
)

Definition at line 31 of file CircleExporter.cc.

32 : _path{path}, _data{}, _model{nullptr}
33{
34 // make sure the architecture is little endian before direct access to flatbuffers
35 assert(FLATBUFFERS_LITTLEENDIAN);
36
37 std::ifstream src(source.c_str(), std::ios::binary);
38 if (src.is_open())
39 {
40 src.seekg(0, std::ios::end);
41 _data.resize(src.tellg());
42 src.seekg(0, std::ios::beg);
43 src.read(&_data[0], static_cast<std::streamsize>(_data.size()));
44 src.close();
45 }
46
47 if (_data.size() == 0)
48 throw std::runtime_error("Invalid source file");
49
50 const auto model = ::circle::GetModel(_data.data());
51 if (!model)
52 throw std::runtime_error("Failed to load original circle file");
53 _model.reset(model->UnPack());
54}

◆ ~CircleExporter()

// Defined in CircleExporter.cc.
//
// Destruction delegates to finish(); nothing else is released explicitly here.
// NOTE(review): destructors are implicitly noexcept, so an exception escaping finish() would
// end in std::terminate. Confirm finish() cannot throw, or report its errors another way.
onert::exporter::CircleExporter::~CircleExporter()
{
  finish();
}

Member Function Documentation

◆ updateMetadata()

// Defined in CircleExporter.cc.
//
// Stores the training information serialized by TrainInfoBuilder (from `training_info`)
// in the model metadata entry named loader::TRAININFO_METADATA_NAME.
// If such an entry exists, its buffer is resized when the size differs and then overwritten.
// Otherwise a new buffer and a new metadata entry pointing at it are appended to the model.
void onert::exporter::CircleExporter::updateMetadata(
  const std::unique_ptr<onert::ir::train::TrainingInfo> &training_info)
{
  // Serialization does not touch _model, so it stays outside the critical section.
  TrainInfoBuilder tbuilder(training_info);
  const std::string metadata_name{loader::TRAININFO_METADATA_NAME};

  // A single lock covers the whole lookup-or-insert. Scanning _model->metadata without it could
  // race with a concurrent push_back, and two callers could both miss the entry and append
  // duplicate metadata records.
  std::lock_guard<std::mutex> guard(_mutex);

  for (const auto &meta : _model->metadata)
  {
    if (meta->name != metadata_name)
      continue;

    // Existing entry: reuse its buffer, resizing only when the serialized size changed.
    auto &buffer = _model->buffers.at(meta->buffer);
    if (tbuilder.size() != buffer->data.size())
    {
      buffer->data.resize(tbuilder.size());
      buffer->size = tbuilder.size();
    }
    memcpy(buffer->data.data(), tbuilder.get(), tbuilder.size());
    return;
  }

  // No existing entry: append a buffer holding the data and a metadata record referring to it.
  auto buffer = std::make_unique<::circle::BufferT>();
  buffer->size = tbuilder.size();
  buffer->data.resize(buffer->size);
  memcpy(buffer->data.data(), tbuilder.get(), buffer->size);

  auto meta = std::make_unique<::circle::MetadataT>();
  meta->name = metadata_name;
  // The new buffer is pushed below, so its index is the current buffer count.
  meta->buffer = static_cast<uint32_t>(_model->buffers.size());

  _model->buffers.push_back(std::move(buffer));
  _model->metadata.push_back(std::move(meta));
}
`TRAININFO_METADATA_NAME` (namespace `onert::loader`) is declared as `const char *const`.

References onert::exporter::TrainInfoBuilder::get(), onert::exporter::TrainInfoBuilder::size(), and onert::loader::TRAININFO_METADATA_NAME.

Referenced by nnfw_session::train_export_circleplus().

◆ updateWeight()

// Defined in CircleExporter.cc.
//
// Copies the current contents of every trainable tensor reported by `exec` into the
// corresponding buffer of the unpacked model (first and only subgraph).
// Throws std::runtime_error when the model does not have exactly one subgraph, when a
// tensor index is invalid, when a referenced buffer index is out of range, or when the
// destination buffer size differs from the tensor's total size.
void onert::exporter::CircleExporter::updateWeight(const std::unique_ptr<onert::exec::Execution> &exec)
{
  auto copy_trained_tensor = [&](const ir::OperandIndex &idx,
                                 const backend::train::ITrainableTensor *tensor) {
    std::lock_guard<std::mutex> lock(_mutex);

    const auto &subgraphs = _model->subgraphs;
    if (subgraphs.size() != 1)
      throw std::runtime_error("Circle does not has valid subgraph or has multiple subgraphs");
    if (!idx.valid())
      throw std::runtime_error("Trainable tensor is invalid");

    auto &buffers = _model->buffers;
    const auto &tensors = subgraphs.at(0)->tensors;
    const auto operand = idx.value();

    uint32_t target;
    if (operand < tensors.size())
    {
      // Known tensor: write into the buffer it already references.
      target = tensors.at(operand)->buffer;
      if (target >= buffers.size())
        throw std::runtime_error("Buffer for trainable tensors is invalid");
    }
    else
    {
      // Index beyond the circle tensor list: append a fresh buffer sized for the tensor.
      // NOTE(review): no tensor visible here is pointed at this new buffer; confirm whether
      // something else links it, otherwise the exported data is unreachable.
      auto fresh = std::make_unique<::circle::BufferT>();
      fresh->size = tensor->total_size();
      fresh->data.resize(fresh->size);
      target = buffers.size();
      buffers.push_back(std::move(fresh));
    }

    const auto &dst = buffers.at(target);
    const auto byte_count = dst->data.size();
    if (byte_count != tensor->total_size())
      throw std::runtime_error("Trained tensor buffer size does not match original tensor's one");

    memcpy(dst->data.data(), tensor->buffer(), byte_count);
  };

  exec->iterateTrainableTensors(copy_trained_tensor);
}
`ir::OperandIndex` is a typedef for `::onert::util::Index<uint32_t, OperandIndexTag>` (defined at line 33 of Index.h).

References onert::util::Index< T, DummyTag >::valid(), and onert::util::Index< T, DummyTag >::value().

Referenced by nnfw_session::train_export_circleplus().


The documentation for this class was generated from the following files: