struct [Name] { ...
[Name]() // constructor { ... }
~[Name]() // destructor { ... } };
146{
148 auto d = _code->
data();
150
151 NetworkStruct network;
152 InvokeFunction invoke;
154
155 auto data_exp = [
this](
const GlobalOffset &off) {
return pp::fmt(_varname,
" + ", off); };
156
157
158 std::map<const ANNBinder *, SubnetInfo> subnet_ctx;
159
179 for (uint32_t n = 0; n < ann_ctx->count(); ++n)
180 {
181 SubnetStructBuilder builder;
182
183 auto subnet_binder = ann_ctx->nth(n);
184 auto subnet_struct_name = pp::fmt("Subnet_", subnet_ctx.size());
185 auto subnet_field_name = pp::fmt("_subnet_", subnet_ctx.size());
186
187
190 {
193 auto base_exp = pp::fmt("reinterpret_cast<const void *>(", data_exp(off), ")");
194 auto size_exp = pp::fmt(
size);
195
196 builder.expr(
info, base_exp, size_exp);
197 }
198 };
199 subnet_binder->module()->operand()->each(emit_weight);
200
201 auto subnet_struct_content = builder.build(subnet_binder);
202
203
204 internal.
append(
"struct ", subnet_struct_name);
207
208 internal.
append(subnet_struct_content->def());
209
210 internal.
append(subnet_struct_name,
"()");
213 internal.
append(subnet_struct_content->ctor());
216
217 internal.
append(
"~", subnet_struct_name,
"()");
220 internal.
append(subnet_struct_content->dtor());
223
226
227
228 network.def.append(subnet_struct_name, " ", subnet_field_name, ";");
229
230
231 SubnetInfo subnet_info;
232
233 subnet_info.struct_name = subnet_struct_name;
234 subnet_info.compilation_field = subnet_struct_content->compilation();
235 subnet_info.field_name = subnet_field_name;
236
237 assert(subnet_ctx.find(subnet_binder) == subnet_ctx.end());
238 subnet_ctx[subnet_binder] = subnet_info;
239 }
240
241 MemoryContext mem;
242
243
244 for (uint32_t n = 0; n <
m->input()->
size(); ++n)
245 {
246 mem.base(
m->input()->at(n)->bag(), pp::fmt(
"net->inputs[", n,
"].ptr"));
247 mem.size(
m->input()->at(n)->bag(), pp::fmt(
"net->inputs[", n,
"].len"));
248 }
249
250
251 for (uint32_t n = 0; n <
m->output()->
size(); ++n)
252 {
253 mem.base(
m->output()->at(n)->bag(), pp::fmt(
"net->outputs[", n,
"].ptr"));
254 mem.size(
m->output()->at(n)->bag(), pp::fmt(
"net->outputs[", n,
"].len"));
255 }
256
257
258
259 for (uint32_t n = 0; n <
m->entity()->bag()->
size(); ++n)
260 {
261 auto bag =
m->entity()->bag()->at(n);
262
263 if (!d->allocated(bag))
264 {
265
266 continue;
267 }
268
269
271
272 auto base_expr = data_exp(
offset);
273 auto size_expr = pp::fmt(bag->size() * sizeof(float));
274
275 mem.base(bag, base_expr);
276 mem.size(bag, size_expr);
277 }
278
279
280 for (const auto &bag : hosted(_code))
281 {
282
283 if (mem.member(bag))
284 {
285 continue;
286 }
287
288 auto name = invoke.local();
289
290 invoke.head.append("auto ", name, " = new uint8_t[", bag->size() * sizeof(float), "];");
291 invoke.tail.append("delete[] ", name, ";");
292
293 mem.base(bag, name);
294 mem.size(bag, pp::fmt(bag->size() * sizeof(float)));
295 }
296
297
298 SubnetBlockCompiler subnet_compiler{mem};
299
300 for (auto it = subnet_ctx.begin(); it != subnet_ctx.end(); ++it)
301 {
302
303 const auto &
info = it->second;
304 subnet_compiler.bind(it->first, pp::fmt(
"net->",
info.field_name,
".",
info.compilation_field));
305 }
306
307 HostBlockCompiler host_compiler{mem};
308
309 for (
auto blk =
m->block()->head(); blk; blk = blk->next())
310 {
311 invoke.body.append("{");
312 invoke.body.indent();
313
314 if (auto binder = ann_ctx->find(blk))
315 {
316
317 auto lines = subnet_compiler.compile(binder);
318 invoke.body.append(*lines);
319 }
320 else
321 {
322
323 auto lines = host_compiler.compile(blk);
324 invoke.body.append(*lines);
325 }
326
327 invoke.body.unindent();
328 invoke.body.append("}");
329 }
330
331
332
333
334 const std::string name{"Network"};
335
337 {
338
339 includes.
append(
"#include <NeuralNetworks.h>");
341
342 includes.
append(
"#include <cstdint>");
343 includes.
append(
"#include <cassert>");
344 includes.
append(
"#include <array>");
345 }
346
348 {
349 net_def.
append(
"struct ", name,
" {");
351 net_def.
append(
"struct Shape { uint32_t rank; const uint32_t *dims; };");
352 net_def.
append(
"struct Input {");
354 net_def.
append(
"const char *name;");
355 net_def.
append(
"const uint8_t *ptr;");
356 net_def.
append(
"unsigned len;");
357 net_def.
append(
"Shape shape;");
360 net_def.
append(
"struct Output {");
362 net_def.
append(
"const char *name;");
363 net_def.
append(
"uint8_t *ptr;");
364 net_def.
append(
"unsigned len;");
365 net_def.
append(
"Shape shape;");
369 net_def.
append(name,
"();");
370 net_def.
append(
"~", name,
"();");
371
373 net_def.
append(network.def);
375
376 net_def.
append(
"std::array<Input, ",
m->input()->size(),
"> inputs;");
377 net_def.
append(
"std::array<Output, ",
m->output()->size(),
"> outputs;");
378
381 }
382
384 {
385 net_ctor.
append(
"Network::Network() {");
387
388
389 for (uint32_t n = 0; n <
m->input()->size(); ++n)
390 {
391 auto input =
m->input()->at(n);
393
395 auto name_exp = pp::fmt("reinterpret_cast<const char *>(", data_exp(name_off), ")");
397 auto dims_exp = pp::fmt("reinterpret_cast<const unsigned *>(", data_exp(dims_off), ")");
398
399 net_ctor.
append(
"inputs.at(", n,
").name = ", name_exp,
";");
400 net_ctor.
append(
"inputs.at(", n,
").shape.rank = ", dims.size(),
";");
401 net_ctor.
append(
"inputs.at(", n,
").shape.dims = ", dims_exp,
";");
402 }
403
404
405 for (uint32_t n = 0; n <
m->output()->
size(); ++n)
406 {
407 auto output =
m->output()->at(n);
409
411 auto name_exp = pp::fmt("reinterpret_cast<const char *>(", data_exp(name_off), ")");
413 auto dims_exp = pp::fmt("reinterpret_cast<const unsigned *>(", data_exp(dims_off), ")");
414
415 net_ctor.
append(
"outputs.at(", n,
").name = ", name_exp,
";");
416 net_ctor.
append(
"outputs.at(", n,
").shape.rank = ", dims.size(),
";");
417 net_ctor.
append(
"outputs.at(", n,
").shape.dims = ", dims_exp,
";");
418 }
419
420
423 }
424
426 {
427 net_dtor.
append(
"Network::~Network() {");
429
432 }
433
435
438 source.
append(
"extern uint8_t ", _varname,
"[];");
440
441 source.
append(
"namespace");
444 source.
append(
"} // namespace");
451
453 source.
append(name,
" *", name,
"_construct() { return new ", name,
"{}; }");
454 source.
append(
"void ", name,
"_destruct(", name,
" *net) { delete net; }");
455
457
458
459 source.
append(
"unsigned ", name,
"_input_count(const ", name,
" *net) {");
461 source.
append(
"return net->inputs.size();");
464
466
467
468 source.
append(
"const char *", name,
"_input_name(const ", name,
" *net, unsigned n) {");
470 source.
append(
"return net->inputs.at(n).name;");
473
474
475 source.
append(
"unsigned ", name,
"_input_rank(const ", name,
" *net, unsigned n) {");
477 source.
append(
"return net->inputs.at(n).shape.rank;");
480
481
482 source.
append(
"unsigned ", name,
"_input_dim(const ", name,
" *net, unsigned n, unsigned axe)");
485 source.
append(
"return net->inputs.at(n).shape.dims[axe];");
488
489
490 source.
append(
"void ", name,
"_input_bind(", name,
491 " *net, unsigned n, const void *ptr, unsigned len) {");
493 source.
append(
"net->inputs.at(n).ptr = reinterpret_cast<const uint8_t *>(ptr);");
494 source.
append(
"net->inputs.at(n).len = len;");
497
499
500
501 source.
append(
"unsigned ", name,
"_output_count(const ", name,
" *net) {");
503 source.
append(
"return net->outputs.size();");
506
508
509
510 source.
append(
"const char *", name,
"_output_name(const ", name,
" *net, unsigned n) {");
512 source.
append(
"return net->outputs.at(n).name;");
515
516
517 source.
append(
"unsigned ", name,
"_output_rank(const ", name,
" *net, unsigned n) {");
519 source.
append(
"return net->outputs.at(n).shape.rank;");
522
523
524 source.
append(
"unsigned ", name,
"_output_dim(const ", name,
" *net, unsigned n, unsigned axe)");
527 source.
append(
"return net->outputs.at(n).shape.dims[axe];");
530
531
532 source.
append(
"void ", name,
"_output_bind(", name,
533 " *net, unsigned n, void *ptr, unsigned len) {");
535 source.
append(
"net->outputs.at(n).ptr = reinterpret_cast<uint8_t *>(ptr);");
536 source.
append(
"net->outputs.at(n).len = len;");
539
541
542 source.
append(
"void ", name,
"_invoke(", name,
" *net) {");
544 source.
append(invoke.head);
545 source.
append(invoke.body);
546 source.
append(invoke.tail);
549
550 os << source;
551}
__global uchar * offset(const Image *img, int x, int y)
volatile const char info[]
Dims< uint32_t > as_dims(const nncc::core::ADT::tensor::Shape &)
coco::Module * module(void) const
coco::Data * data(void) const
static GlobalOffset data_offset(const ann::Operand *)
static GlobalOffset name_offset(const coco::Input *)
static GlobalOffset dims_offset(const coco::Input *)
static const ANNContext * context(const coco::Module *m)