ONE - On-device Neural Engine
onert::compiler::train::TrainingCompiler Class Reference

Class to compile NN package. More...

#include <TrainingCompiler.h>

Collaboration diagram for onert::compiler::train::TrainingCompiler:

Public Member Functions

 TrainingCompiler (const std::shared_ptr< ir::NNPkg > &nnpkg, CompilerOptions *copts, const ir::train::TrainingInfo &training_info)
 Construct a new TrainingCompiler object for an nnpkg.
 
 TrainingCompiler (void)=delete
 Deleted default constructor.
 
 ~TrainingCompiler ()=default
 Destroy the TrainingCompiler object.
 
std::shared_ptr< CompilerArtifact > compile (void)
 Do compilation with the options.
 
- Public Member Functions inherited from onert::compiler::ICompiler
virtual ~ICompiler ()=default
 Virtual ICompiler destructor.
 

Detailed Description

Class to compile NN package.

Definition at line 40 of file TrainingCompiler.h.
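
A minimal usage sketch (not part of the ONE sources): it assumes the caller has already obtained an ir::NNPkg, CompilerOptions, and ir::train::TrainingInfo, e.g. from the frontend loader and training session setup; only the TrainingCompiler constructor and compile() call come from this class.

#include "TrainingCompiler.h"

// Hypothetical helper, not part of onert: nnpkg, copts, and training_info are
// assumed to be prepared by the caller.
std::shared_ptr<onert::compiler::CompilerArtifact>
compileForTraining(const std::shared_ptr<onert::ir::NNPkg> &nnpkg,
                   onert::compiler::CompilerOptions *copts,
                   const onert::ir::train::TrainingInfo &training_info)
{
  // The constructor currently accepts only a single model with a single
  // subgraph (see the constructor definition below).
  onert::compiler::train::TrainingCompiler compiler{nnpkg, copts, training_info};

  // Runs the whole pipeline: mandatory passes, loss insertion, lowering,
  // shape inference, and executor generation.
  return compiler.compile();
}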

Constructor & Destructor Documentation

◆ TrainingCompiler() [1/2]

onert::compiler::train::TrainingCompiler::TrainingCompiler ( const std::shared_ptr< ir::NNPkg > &  nnpkg,
CompilerOptions *  copts,
const ir::train::TrainingInfo &  training_info
)
explicit

Construct a new TrainingCompiler object for an nnpkg.

Parameters
  [in]  nnpkg          nnpkg to compile
  [in]  copts          compiler options
  [in]  training_info  training information

Definition at line 48 of file TrainingCompiler.cc.

  : _model{nnpkg->primary_model()}, _options{copts}, _training_info{training_info}
{
  if (nnpkg->model_count() > 1)
    throw std::runtime_error("TrainingCompiler does not support multiple models yet");

  if (nnpkg->primary_model()->subgraphs_count() > 1)
    throw std::runtime_error("TrainingCompiler does not support multiple subgraphs yet");
}
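
The checks above mean that a multi-model or multi-subgraph package is rejected at construction time. A hedged sketch of the failure mode; multi_model_nnpkg, copts, and training_info are hypothetical caller variables:

// Fragment (hypothetical caller code); multi_model_nnpkg has model_count() > 1.
try
{
  onert::compiler::train::TrainingCompiler compiler{multi_model_nnpkg, copts, training_info};
}
catch (const std::runtime_error &e)
{
  // Thrown by the constructor body above:
  // "TrainingCompiler does not support multiple models yet"
}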

◆ TrainingCompiler() [2/2]

onert::compiler::train::TrainingCompiler::TrainingCompiler ( void  )
delete

Default construction is deleted; a TrainingCompiler must be constructed with an nnpkg, compiler options, and training info.

◆ ~TrainingCompiler()

onert::compiler::train::TrainingCompiler::~TrainingCompiler ( )
default

Destroy the TrainingCompiler object.

Member Function Documentation

◆ compile()

std::shared_ptr< CompilerArtifact > onert::compiler::train::TrainingCompiler::compile ( void  )
virtual

Do compilation with the options.

Returns
std::shared_ptr<CompilerArtifact> Executors as a result of compilation

Implements onert::compiler::ICompiler.

Definition at line 59 of file TrainingCompiler.cc.

{
  /***************************************************
   * Prepare compilation phase
   ***************************************************/
  if (!_options)
    throw std::runtime_error{"Empty compile option"};

  // Mode check
  // TODO handle option for each model
  if (_options->he_profiling_mode)
  {
    if (!_options->he_scheduler)
      throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");

    if (_options->executor != "Dataflow")
      throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
  }

  _options->forceInternalOptions();
  _options->verboseOptions();

  auto custom_kernel_builder = _model->getKernelBuilder();

  _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
    auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
    // Mandatory passes
    compiler::pass::PassRunner{}
      .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg))
      .append(std::make_unique<compiler::pass::OddOutputPass>(subg))
      .run();

    // Optimizations
    compiler::pass::PassRunner{}
      .append(std::make_unique<compiler::pass::UnusedOperandEliminationPass>(subg))
      .run();
  });

  std::unordered_map<ir::SubgraphIndex, std::shared_ptr<ir::train::TrainableGraph>>
    trainable_subgraphs;

  if (_model->hasOnly<ir::Graph>())
  {
    // Create trainable subgraphs by copy and converting inference model
    _model->iterate([&](const ir::SubgraphIndex &subg_index, const ir::IGraph &graph) {
      const auto &subg = nnfw::misc::polymorphic_downcast<const ir::Graph &>(graph);
      // Create TrainableGraph by copying Graph
      auto trainable_subg = std::make_shared<ir::train::TrainableGraph>(subg);

      // Convert operations to trainable operations
      auto converter = TrainableOperationConverter{*trainable_subg, &_training_info};
      ir::OperationIndex min_trainable_op_idx;
      subg.operations().iterate(
        [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &op) {
          auto trainable_op = converter(op);
          if (_training_info.getTrainableOps().find(op_index) !=
              std::end(_training_info.getTrainableOps()))
          {
            trainable_op->enableWeightsUpdate();
            if (op_index.value() < min_trainable_op_idx.value())
            {
              min_trainable_op_idx = op_index;
            }
          }
          [[maybe_unused]] auto gen_index =
            trainable_subg->replaceOperation(op_index, std::move(trainable_op));
          assert(gen_index == op_index);
        });

      for (ir::OperationIndex idx{min_trainable_op_idx};
           idx.value() < trainable_subg->operations().size(); idx++)
      {
        trainable_subg->enableBackward(idx);
      }

      trainable_subgraphs[subg_index] = std::move(trainable_subg);
    });
  }
  else
  {
    // TODO Support models that have TrainableGraphs
    throw std::runtime_error("TrainingCompiler: Invalid model");
  }

  // operation
  _model.reset();

  // TODO Handle dump level for each model
  auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level);
  onert::dumper::dot::DotDumper dot_dumper(dump_level);

  for (const auto &[subg_index, subg] : trainable_subgraphs)
  {
    dot_dumper.dump(*subg, nnfw::misc::str("before_loss_insertion-", subg_index.value()));
  }

  // Apply pass for trainable subgraphs
  for (auto &&[subg_index, trainable_subg] : trainable_subgraphs)
  {
    compiler::pass::PassRunner{}
      .append(std::make_unique<train::pass::LossInsertionPass>(*trainable_subg, &_training_info,
                                                               subg_index))
      .run();
  }

  for (const auto &[subg_index, subg] : trainable_subgraphs)
  {
    dot_dumper.dump(*subg, nnfw::misc::str("after_loss_insertion-", subg_index.value()));
  }

  for (auto &&[subg_index, subg] : trainable_subgraphs)
  {
    subg->updateGraphDependency();
    subg->verify();

    dot_dumper.dump(*subg,
                    nnfw::misc::str("after_initializing_training_usedefs-", subg_index.value()));
  }

  // Change input shape according to batch_size
  for (auto &&pair : trainable_subgraphs)
  {
    auto trainable_subg = pair.second;

    for (const auto &ind : trainable_subg->getInputs())
    {
      auto &input = trainable_subg->operands().at(ind);
      auto new_shape = input.info().shape();
      // TODO Consider batch size index
      if (new_shape.dim(0) != 1)
        throw std::runtime_error("the first dim is not 1. It is not supported yet.");
      new_shape.dim(0) = _training_info.batchSize();
      input.info().shape(new_shape);
    }
  }

  /***************************************************
   * Backend independent analysis & optimization phase
   ***************************************************/
  // Tracing context
  auto tracing_ctx = std::make_unique<util::TracingCtx>();

  // Lower: Assign backend
  std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::train::LoweredTrainableGraph>>
    lowered_subgs;
  {
    for (auto &&[subg_index, trainable_subg] : trainable_subgraphs)
    {
      // Lower: Assign backend
      lowered_subgs[subg_index] =
        std::make_unique<compiler::train::LoweredTrainableGraph>(*trainable_subg, *_options);
      // Set tracing_ctx for copied graph
      tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value());
    }
  }

  for (const auto &[subg_index, lowered_subg] : lowered_subgs)
  {
    dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value()));
  }

  // Set operands' info for back propagation as default tensor info
  for (const auto &pair : lowered_subgs)
  {
    auto lowered_subg = pair.second.get();
    auto &tgraph = lowered_subg->trainable_graph();
    tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) {
      if (!obj.isConstant())
      {
        auto bwd_operand = std::make_unique<ir::Operand>(obj);
        [[maybe_unused]] const auto gen_index =
          tgraph.addBackwardOperand(index, std::move(bwd_operand));
        assert(gen_index == index);
      }
    });
  }

  // Shape inference.
  {
    // Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
    // recursively
    std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
      createStaticShapeInferers(lowered_subgs);

    const auto primary_subg_idx = ir::SubgraphIndex{0};
    inferers.at(primary_subg_idx)->infer();

    for (const auto &pair_inferer : inferers)
    {
      const auto inferer = pair_inferer.second.get();
      inferer->dump();
    }

    // NOTE StaticBackwardShapeInferer is allocated for each subgraph,
    // so it does not support models that have controlflow operations yet.
    for (auto &&pair : lowered_subgs)
    {
      auto &lowered_subg = pair.second;
      auto inferer = std::make_unique<StaticBackwardShapeInferer>(lowered_subg.get());
      inferer->infer();
      inferer->dump();
    }
  }

  // Shape validation
  for (const auto &pair : lowered_subgs)
  {
    auto &lowered_subg = pair.second;
    compiler::ShapeValidator{lowered_subg->graph()}();
  }

  // TODO Validate shapes of the tensors for back propagation

  /*************************************************************
   * Backend independent analysis & optimization phase finished
   *************************************************************/
  auto executors = std::make_shared<exec::train::TrainableExecutors>();
  for (auto &&[subg_index, lowered_subg] : lowered_subgs)
  {
    auto const model_index = ir::ModelIndex{0};
    auto const indexed_ranks = lowered_subg->indexed_ranks();

    ir::OperationDumper dumper("Executor generation of Subgraph " +
                               std::to_string(subg_index.value()));
    lowered_subg->graph().operations().iterate(
      [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });

    ExecutorFactoryArgs args;
    args.tracing_ctx = tracing_ctx.get();
    args.options = _options;
    args.model_index = model_index;
    args.custom_kernel_builder = custom_kernel_builder;
    auto executor = std::unique_ptr<exec::IExecutor>{
      ExecutorFactory::get().create(std::move(lowered_subg), executors, args, _training_info)};
    executor->setIndexedRanks(indexed_ranks);
    executors->emplace(model_index, subg_index, std::move(executor));
  }

  /********************************
   * Code generation phase finished
   ********************************/
  return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
}
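
For orientation, a short hedged sketch of what the caller receives. compile() packs the exec::train::TrainableExecutors and the util::TracingCtx into the CompilerArtifact on the last line above; the member name _executors used below is an assumption about CompilerArtifact (its interface is not documented on this page) and should be checked against CompilerArtifact.h:

// Reusing the compiler object from the sketch in the Detailed Description.
auto artifact = compiler.compile();

// Assumption: CompilerArtifact exposes the executors it was constructed with
// as a public member named _executors.
std::shared_ptr<onert::exec::IExecutors> executors = artifact->_executors;

// The executors hold one executor per subgraph, emplaced under
// (ModelIndex{0}, subg_index) in the code-generation loop above.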

References onert::ir::IOperation::accept(), onert::compiler::pass::PassRunner::append(), onert::ir::train::TrainingInfo::batchSize(), onert::compiler::ExecutorFactory::create(), onert::dumper::dot::DotDumper::dump(), onert::compiler::CompilerOptions::executor, onert::compiler::CompilerOptions::forceInternalOptions(), onert::compiler::ExecutorFactory::get(), onert::ir::train::TrainingInfo::getTrainableOps(), onert::compiler::CompilerOptions::graph_dump_level, onert::compiler::CompilerOptions::he_profiling_mode, onert::compiler::CompilerOptions::he_scheduler, onert::compiler::pass::PassRunner::run(), run(), onert::exec::IExecutor::setIndexedRanks(), nnfw::misc::str(), onert::util::Index< T, DummyTag >::value(), and onert::compiler::CompilerOptions::verboseOptions().


The documentation for this class was generated from the following files:
TrainingCompiler.h
TrainingCompiler.cc