ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::compiler::train::TrainingCompiler Class Reference

Class to compile NN package. More...

#include <TrainingCompiler.h>

Collaboration diagram for onert::compiler::train::TrainingCompiler:

Public Member Functions

 TrainingCompiler (const std::shared_ptr< ir::NNPkg > &nnpkg, CompilerOptions *copts, const ir::train::TrainingInfo &training_info)
 Construct a new TrainingCompiler object for an nnpkg.
 
 TrainingCompiler (void)=delete
 Construct a TrainingCompiler object.
 
 ~TrainingCompiler ()=default
 Destroy the TrainingCompiler object.
 
std::shared_ptr< CompilerArtifact > compile (void)
 Do compilation with the options.
 
- Public Member Functions inherited from onert::compiler::ICompiler
virtual ~ICompiler ()=default
 Virtual ICompiler destructor.
 

Detailed Description

Class to compile NN package.

Definition at line 36 of file TrainingCompiler.h.

Constructor & Destructor Documentation

◆ TrainingCompiler() [1/2]

onert::compiler::train::TrainingCompiler::TrainingCompiler ( const std::shared_ptr< ir::NNPkg > &  nnpkg,
CompilerOptions *  copts,
const ir::train::TrainingInfo &  training_info 
)
explicit

Construct a new TrainingCompiler object for an nnpkg.

Parameters
[in]	nnpkg	nnpkg to compile
[in]	copts	compiler options
[in]	training_info	training information

Definition at line 44 of file TrainingCompiler.cc.

46 : _model{nnpkg->primary_model()}, _options{copts}, _training_info{training_info}
47{
48 if (nnpkg->model_count() > 1)
49 throw std::runtime_error("TrainingCompiler does not support multiple models yet");
50
51 if (nnpkg->primary_model()->subgraphs_count() > 1)
52 throw std::runtime_error("TrainingCompiler does not support multiple subgraphs yet");
53}

◆ TrainingCompiler() [2/2]

onert::compiler::train::TrainingCompiler::TrainingCompiler ( void  )
delete

Construct a TrainingCompiler object.

◆ ~TrainingCompiler()

onert::compiler::train::TrainingCompiler::~TrainingCompiler ( )
default

Destroy the TrainingCompiler object.

Member Function Documentation

◆ compile()

std::shared_ptr< CompilerArtifact > onert::compiler::train::TrainingCompiler::compile ( void  )
virtual

Do compilation with the options.

Returns
std::shared_ptr<CompilerArtifact> Executors as a result of compilation

Implements onert::compiler::ICompiler.

Definition at line 55 of file TrainingCompiler.cc.

56{
57 /***************************************************
58 * Prepare compilation phase
59 ***************************************************/
60 if (!_options)
61 throw std::runtime_error{"Empty compile option"};
62
63 // Mode check
64 // TODO handle option for each model
65 if (_options->he_profiling_mode)
66 {
67 if (!_options->he_scheduler)
68 throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");
69
70 if (_options->executor != "Dataflow")
71 throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
72 }
73
74 _options->forceInternalOptions();
75 _options->verboseOptions();
76
77 auto custom_kernel_builder = _model->getKernelBuilder();
78
79 _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
80 auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
81 // Mandatory passes
82 compiler::pass::PassRunner{}
83 .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg))
84 .append(std::make_unique<compiler::pass::OddOutputPass>(subg))
85 .run();
86
87 // Optimizations
88 compiler::pass::PassRunner{}
89 .append(std::make_unique<compiler::pass::UnusedOperandEliminationPass>(subg))
90 .run();
91 });
92
93 std::unordered_map<ir::SubgraphIndex, std::shared_ptr<ir::train::TrainableGraph>>
94 trainable_subgraphs;
95
96 if (_model->hasOnly<ir::Graph>())
97 {
98 // Create trainable subgraphs by copy and converting inference model
99 _model->iterate([&](const ir::SubgraphIndex &subg_index, const ir::IGraph &graph) {
100 const auto &subg = nnfw::misc::polymorphic_downcast<const ir::Graph &>(graph);
101 // Create TrainableGraph by copying Graph
102 auto trainable_subg = std::make_shared<ir::train::TrainableGraph>(subg);
103
104 // Convert operations to trainable operations
105 auto converter = TrainableOperationConverter{*trainable_subg, &_training_info};
106 ir::OperationIndex min_trainable_op_idx;
107 subg.operations().iterate(
108 [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &op) {
109 auto trainable_op = converter(op);
110 if (_training_info.getTrainableOps().find(op_index) !=
111 std::end(_training_info.getTrainableOps()))
112 {
113 trainable_op->enableWeightsUpdate();
114 if (op_index.value() < min_trainable_op_idx.value())
115 {
116 min_trainable_op_idx = op_index;
117 }
118 }
119 [[maybe_unused]] auto gen_index =
120 trainable_subg->replaceOperation(op_index, std::move(trainable_op));
121 assert(gen_index == op_index);
122 });
123
124 for (ir::OperationIndex idx{min_trainable_op_idx};
125 idx.value() < trainable_subg->operations().size(); idx++)
126 {
127 trainable_subg->enableBackward(idx);
128 }
129
130 trainable_subgraphs[subg_index] = std::move(trainable_subg);
131 });
132 }
133 else
134 {
135 // TODO Support models that have TrainableGraphs
136 throw std::runtime_error("TrainingCompiler: Invalid model");
137 }
138
139 // operation
140 _model.reset();
141
142 // TODO Handle dump level for each model
143 auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level);
144 onert::dumper::dot::DotDumper dot_dumper(dump_level);
145
146 for (const auto &[subg_index, subg] : trainable_subgraphs)
147 {
148 dot_dumper.dump(*subg, nnfw::misc::str("before_loss_insertion-", subg_index.value()));
149 }
150
151 // Apply pass for trainable subgraphs
152 for (auto &&[subg_index, trainable_subg] : trainable_subgraphs)
153 {
154 compiler::pass::PassRunner{}
155 .append(std::make_unique<train::pass::LossInsertionPass>(*trainable_subg, &_training_info,
156 subg_index))
157 .run();
158 }
159
160 for (const auto &[subg_index, subg] : trainable_subgraphs)
161 {
162 dot_dumper.dump(*subg, nnfw::misc::str("after_loss_insertion-", subg_index.value()));
163 }
164
165 for (auto &&[subg_index, subg] : trainable_subgraphs)
166 {
167 subg->updateGraphDependency();
168 subg->verify();
169
170 dot_dumper.dump(*subg,
171 nnfw::misc::str("after_initializing_training_usedefs-", subg_index.value()));
172 }
173
174 // Change input shape according to batch_size
175 for (auto &&pair : trainable_subgraphs)
176 {
177 auto trainable_subg = pair.second;
178
179 for (const auto &ind : trainable_subg->getInputs())
180 {
181 auto &input = trainable_subg->operands().at(ind);
182 auto new_shape = input.info().shape();
183 // TODO Consider batch size index
184 if (new_shape.dim(0) != 1)
185 throw std::runtime_error("the first dim is not 1. It is not supported yet.");
186 new_shape.dim(0) = _training_info.batchSize();
187 input.info().shape(new_shape);
188 }
189 }
190
191 /***************************************************
192 * Backend independent analysis & optimization phase
193 ***************************************************/
194 // Tracing context
195 auto tracing_ctx = std::make_unique<util::TracingCtx>();
196
197 // Lower: Assign backend
198 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::train::LoweredTrainableGraph>>
199 lowered_subgs;
200 {
201 for (auto &&[subg_index, trainable_subg] : trainable_subgraphs)
202 {
203 // Lower: Assign backend
204 lowered_subgs[subg_index] =
205 std::make_unique<compiler::train::LoweredTrainableGraph>(*trainable_subg, *_options);
206 // Set tracing_ctx for copied graph
207 tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value());
208 }
209 }
210
211 for (const auto &[subg_index, lowered_subg] : lowered_subgs)
212 {
213 dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value()));
214 }
215
216 // Set operands' info for back propagation as default tensor info
217 for (const auto &pair : lowered_subgs)
218 {
219 auto lowered_subg = pair.second.get();
220 auto &tgraph = lowered_subg->trainable_graph();
221 tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) {
222 if (!obj.isConstant())
223 {
224 auto bwd_operand = std::make_unique<ir::Operand>(obj);
225 [[maybe_unused]] const auto gen_index =
226 tgraph.addBackwardOperand(index, std::move(bwd_operand));
227 assert(gen_index == index);
228 }
229 });
230 }
231
232 // Shape inference.
233 {
234 // Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
235 // recursively
236 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
237 createStaticShapeInferers(lowered_subgs);
238
239 const auto primary_subg_idx = ir::SubgraphIndex{0};
240 inferers.at(primary_subg_idx)->infer();
241
242 for (const auto &pair_inferer : inferers)
243 {
244 const auto inferer = pair_inferer.second.get();
245 inferer->dump();
246 }
247
248 // NOTE StaticBackwardShapeInferer is allocated for each subgraph,
249 // so it does not support models that have controlflow operations yet.
250 for (auto &&pair : lowered_subgs)
251 {
252 auto &lowered_subg = pair.second;
253 auto inferer = std::make_unique<StaticBackwardShapeInferer>(lowered_subg.get());
254 inferer->infer();
255 inferer->dump();
256 }
257 }
258
259 // Shape validation
260 for (const auto &pair : lowered_subgs)
261 {
262 auto &lowered_subg = pair.second;
263 compiler::ShapeValidator{lowered_subg->graph()}();
264 }
265
266 // TODO Validate shapes of the tensors for back propagation
267
268 /*************************************************************
269 * Backend independent analysis & optimization phase finished
270 *************************************************************/
271 auto executors = std::make_shared<exec::train::TrainableExecutors>();
272 for (auto &&[subg_index, lowered_subg] : lowered_subgs)
273 {
274 auto const model_index = ir::ModelIndex{0};
275 auto const indexed_ranks = lowered_subg->indexed_ranks();
276
277 ir::OperationDumper dumper("Executor generation of Subgraph " +
278 std::to_string(subg_index.value()));
279 lowered_subg->graph().operations().iterate(
280 [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });
281
282 ExecutorFactoryArgs args;
283 args.tracing_ctx = tracing_ctx.get();
284 args.options = _options;
285 args.model_index = model_index;
286 args.custom_kernel_builder = custom_kernel_builder;
287 auto executor = std::unique_ptr<exec::IExecutor>{
288 ExecutorFactory::get().create(std::move(lowered_subg), executors, args, _training_info)};
289 executor->setIndexedRanks(indexed_ranks);
290 executors->emplace(model_index, subg_index, std::move(executor));
291 }
292
293 /********************************
294 * Code generation phase finished
295 ********************************/
296 return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
297}
exec::IExecutor * create(std::unique_ptr< compiler::LoweredGraph > lowered_graph, const std::shared_ptr< exec::IExecutors > &executors, const ExecutorFactoryArgs &args)
static ExecutorFactory & get()
const std::set< OperationIndex > & getTrainableOps() const
T value() const
Return underlying value.
Definition Index.h:137
args
Definition infer.py:21
std::string str(Args &&...args)
::onert::util::Index< uint32_t, OperationIndexTag > OperationIndex
Definition Index.h:30
::onert::util::Index< uint16_t, ModelIndexTag > ModelIndex
Definition Index.h:42
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:33
::onert::util::Index< uint16_t, SubgraphIndexTag > SubgraphIndex
Definition Index.h:39
void forceInternalOptions()
Force default values of CompilerOptions for correct compilations.
void verboseOptions()
Print option value.
virtual void setIndexedRanks(std::shared_ptr< ir::OperationIndexMap< int64_t > >)=0
Set an ordering on operations.

References onert::ir::IOperation::accept(), onert::compiler::pass::PassRunner::append(), onert::ir::train::TrainingInfo::batchSize(), onert::compiler::ExecutorFactory::create(), onert::dumper::dot::DotDumper::dump(), onert::compiler::CompilerOptions::executor, onert::compiler::CompilerOptions::forceInternalOptions(), onert::compiler::ExecutorFactory::get(), onert::ir::train::TrainingInfo::getTrainableOps(), onert::compiler::CompilerOptions::graph_dump_level, onert::compiler::CompilerOptions::he_profiling_mode, onert::compiler::CompilerOptions::he_scheduler, onert::compiler::pass::PassRunner::run(), run(), onert::exec::IExecutor::setIndexedRanks(), nnfw::misc::str(), onert::util::Index< T, DummyTag >::value(), and onert::compiler::CompilerOptions::verboseOptions().


The documentation for this class was generated from the following files: