ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::exporter::CircleExporter Class Reference

#include <CircleExporter.h>

Public Member Functions

 CircleExporter (const std::string &source, const std::string &path)
 
 ~CircleExporter ()
 
void updateWeight (const std::unique_ptr< onert::exec::Execution > &exec)
 
void updateMetadata (const std::unique_ptr< onert::ir::train::TrainingInfo > &training_info)
 

Detailed Description

Definition at line 44 of file CircleExporter.h.

Constructor & Destructor Documentation

◆ CircleExporter()

onert::exporter::CircleExporter::CircleExporter ( const std::string &  source,
const std::string &  path 
)

Definition at line 32 of file CircleExporter.cc.

33 : _path{path}, _data{}, _model{nullptr}
34{
35 // make sure the architecture is little endian before direct access to flatbuffers
36 assert(FLATBUFFERS_LITTLEENDIAN);
37
38 std::ifstream src(source.c_str(), std::ios::binary);
39 if (src.is_open())
40 {
41 src.seekg(0, std::ios::end);
42 _data.resize(src.tellg());
43 src.seekg(0, std::ios::beg);
44 src.read(&_data[0], static_cast<std::streamsize>(_data.size()));
45 src.close();
46 }
47
48 if (_data.size() == 0)
49 throw std::runtime_error("Invalid source file");
50
51 const auto model = ::circle::GetModel(_data.data());
52 if (!model)
53 throw std::runtime_error("Failed to load original circle file");
54 _model.reset(model->UnPack());
55}

◆ ~CircleExporter()

onert::exporter::CircleExporter::~CircleExporter ( )

Definition at line 57 of file CircleExporter.cc.

57{ finish(); }

Member Function Documentation

◆ updateMetadata()

void onert::exporter::CircleExporter::updateMetadata ( const std::unique_ptr< onert::ir::train::TrainingInfo > &  training_info)

Definition at line 99 of file CircleExporter.cc.

100{
101 const char *const TRAININFO_METADATA_NAME = "CIRCLE_TRAINING";
102
103 TrainInfoBuilder tbuilder(training_info);
104 bool found = false;
105 for (const auto &meta : _model->metadata)
106 {
107 if (meta->name == std::string{TRAININFO_METADATA_NAME})
108 {
109 std::lock_guard<std::mutex> guard(_mutex);
110 const uint32_t buf_idx = meta->buffer;
111 auto &buffer = _model->buffers.at(buf_idx);
112
113 if (tbuilder.size() != buffer->data.size())
114 {
115 buffer->data.resize(tbuilder.size());
116 buffer->size = tbuilder.size();
117 }
118
119 memcpy(buffer->data.data(), tbuilder.get(), tbuilder.size());
120 found = true;
121 break;
122 }
123 }
124
125 if (!found)
126 {
127 std::lock_guard<std::mutex> guard(_mutex);
128 auto buffer = std::make_unique<::circle::BufferT>();
129 buffer->size = tbuilder.size();
130 buffer->data.resize(buffer->size);
131 memcpy(buffer->data.data(), tbuilder.get(), buffer->size);
132
133 auto meta = std::make_unique<::circle::MetadataT>();
134 meta->name = std::string{TRAININFO_METADATA_NAME};
135 meta->buffer = _model->buffers.size();
136
137 _model->buffers.push_back(std::move(buffer));
138 _model->metadata.push_back(std::move(meta));
139 }
140}
const char *const TRAININFO_METADATA_NAME

References onert::exporter::TrainInfoBuilder::get(), and onert::exporter::TrainInfoBuilder::size().

Referenced by nnfw_session::train_export_circleplus().

◆ updateWeight()

void onert::exporter::CircleExporter::updateWeight ( const std::unique_ptr< onert::exec::Execution > &  exec)

Definition at line 59 of file CircleExporter.cc.

60{
61 exec->iterateTrainableTensors(
62 [&](const ir::OperandIndex &idx, const backend::train::ITrainableTensor *tensor) {
63 std::lock_guard<std::mutex> guard(_mutex);
64 const auto &subgs = _model->subgraphs;
65 if (subgs.size() != 1)
66 throw std::runtime_error("Circle does not has valid subgraph or has multiple subgraphs");
67
68 if (!idx.valid())
69 throw std::runtime_error("Trainable tensor is invalid");
70
71 uint32_t buf_idx = -1;
72 const auto &subg = subgs.at(0); // Get 1st subgraph
73 if (idx.value() >= subg->tensors.size())
74 {
75 auto buffer = std::make_unique<::circle::BufferT>();
76 buffer->size = tensor->total_size();
77 buffer->data.resize(buffer->size);
78
79 buf_idx = _model->buffers.size();
80 _model->buffers.push_back(std::move(buffer));
81 }
82 else
83 {
84 buf_idx = subg->tensors.at(idx.value())->buffer;
85 if (buf_idx >= _model->buffers.size())
86 throw std::runtime_error("Buffer for trainable tensors is invalid");
87 }
88
89 const auto &buffer = _model->buffers.at(buf_idx);
90
91 auto org_buf_sz = buffer->data.size();
92 if (org_buf_sz != tensor->total_size())
93 throw std::runtime_error("Trained tensor buffer size does not match original tensor's one");
94
95 memcpy(buffer->data.data(), tensor->buffer(), org_buf_sz);
96 });
97}
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:35

References onert::util::Index< T, DummyTag >::valid(), and onert::util::Index< T, DummyTag >::value().

Referenced by nnfw_session::train_export_circleplus().


The documentation for this class was generated from the following files: