ONE - On-device Neural Engine
Loading...
Searching...
No Matches
caffeimport::ScaleBuilder Class Reference [final]

#include <Scale.h>

Collaboration diagram for caffeimport::ScaleBuilder:

Public Member Functions

void build (const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override
 
- Public Member Functions inherited from caffeimport::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

Definition at line 27 of file Scale.h.

Member Function Documentation

◆ build()

void caffeimport::ScaleBuilder::build (const ::caffe::LayerParameter & layer,
                                       GraphBuilderContext * context
                                      ) const
[override, virtual]

Implements caffeimport::GraphBuilder.

Definition at line 32 of file Scale.cpp.

33{
34 coco::Module *module = context->module();
35 coco::Data *data = context->data();
36 coco::Block *blk = context->block();
37 std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
38 std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
39 WeightContext &weight_ctx = context->weight_ctx();
40
41 // TODO Support Scale layer with 2 bottoms
42 assert(layer.bottom().size() == 1);
43 assert(layer.top().size() == 1);
44
45 assert(layer.has_scale_param());
46 const auto &param = layer.scale_param();
47
48 assert(param.axis() == 1);
49 assert(!param.has_num_axes());
50
51 assert(weight_ctx.blob_count(layer.name()) >= 1);
52
53 // NOTE The shape of "Scale" output is same as that of its input
54 // NOTE The current implementation assumes that input/output is of feature type
55 // TODO Support generic tensor arguments
56 auto shape = shape_ctx.at(layer.bottom(0));
57
58 coco::Bag *last_bag = bag_ctx.at(layer.bottom(0));
59
60 // Create channel-wise multiplication
61 {
62 auto in_bag = last_bag;
63 auto in_obj = module->entity()->object()->create<coco::FeatureObject>();
64
65 in_obj->bag(in_bag);
66 in_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
67
68 auto factor_bag = module->entity()->bag()->create(num_elements(shape));
69 auto factor_obj = module->entity()->object()->create<coco::FeatureObject>();
70
71 factor_obj->bag(factor_bag);
72 factor_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(shape)));
73
74 auto out_bag = module->entity()->bag()->create(num_elements(shape));
75 auto out_obj = module->entity()->object()->create<coco::FeatureObject>();
76
77 out_obj->bag(out_bag);
78 out_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
79
80 auto mul_op = op_builder(module).load(factor_obj).load(in_obj).mul().pop();
81 auto mul_ins = instr_builder(module).eval(out_obj, mul_op);
82
83 blk->instr()->append(mul_ins);
84
85 // Fill "factor" data
86 {
87 data->f32()->allocate(factor_bag);
88
89 auto span = data->f32()->weight(factor_bag);
90 auto blob = weight_ctx.blob_get(layer.name(), 0);
91
92 for (uint32_t ch = 0; ch < factor_obj->shape().depth(); ++ch)
93 {
94 span[ch] = blob->data(ch);
95 }
96 }
97
98 // Update "last_bag"
99 last_bag = out_bag;
100 }
101
102 assert(last_bag != nullptr);
103
104 // Create bias addition (as channel-wise addition)
105 if (param.bias_term())
106 {
107 assert(weight_ctx.blob_count(layer.name()) >= 2);
108
109 auto in_bag = last_bag; /* Use the output of the last computation as an input */
110 auto in_obj = module->entity()->object()->create<coco::FeatureObject>();
111
112 in_obj->bag(in_bag);
113 in_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
114
115 auto bias_bag = module->entity()->bag()->create(num_elements(shape));
116 auto bias_obj = module->entity()->object()->create<coco::FeatureObject>();
117
118 bias_obj->bag(bias_bag);
119 bias_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(shape)));
120
121 auto out_bag = module->entity()->bag()->create(num_elements(shape));
122 auto out_obj = module->entity()->object()->create<coco::FeatureObject>();
123
124 out_obj->bag(out_bag);
125 out_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
126
127 auto add_op = op_builder(module).load(bias_obj).load(in_obj).add().pop();
128 auto add_ins = instr_builder(module).eval(out_obj, add_op);
129
130 blk->instr()->append(add_ins);
131
132 // Fill bias data
133 {
134 data->f32()->allocate(bias_bag);
135
136 auto bias_span = data->f32()->weight(bias_bag);
137 auto bias_blob = weight_ctx.blob_get(layer.name(), 1);
138
139 for (uint32_t ch = 0; ch < bias_obj->shape().depth(); ++ch)
140 {
141 bias_span[ch] = bias_blob->data(ch);
142 }
143 }
144
145 // Update "last_bag"
146 last_bag = out_bag;
147 }
148
149 // Update bag and shape context
150 {
151 const auto &out_name = layer.top(0);
152 const auto &out_bag = last_bag;
153 const auto &out_shape = shape;
154
155 bag_ctx[out_name] = out_bag;
156 shape_ctx[out_name] = out_shape;
157 }
158}
OpBuilder op_builder(coco::Module *m)
Definition IRBuilder.h:144
InstrBuilder instr_builder(coco::Module *m)
Definition IRBuilder.h:174
coco::Eval * eval(coco::Object *out, coco::Op *op) const
Create "Eval" instruction with a given "Object" and "Op".
Definition IRBuilder.h:162
OpBuilder & load(coco::Object *obj)
Create "Load" op and push it onto the internal stack.
Definition IRBuilder.h:70
OpBuilder & mul(void)
Create "Mul" op and push it onto the internal stack.
Definition IRBuilder.h:100
coco::Op * pop(void)
Pop op from the internal stack.
Definition IRBuilder.h:116
OpBuilder & add(void)
Create "Add" op and push it onto the internal stack.
Definition IRBuilder.h:84
A collection of (abstracted) elements of the same type.
Definition Bag.h:48
A unit of (grouped) instructions.
Definition Block.h:40
InstrList * instr(void)
Definition Block.h:65
void append(Child *child)
static std::unique_ptr< BCHW > create(const nncc::core::ADT::feature::Shape &shape)
static std::unique_ptr< BC > create(const nncc::core::ADT::feature::Shape &shape)
Top-level element of coco IR which represents a neural network.
Definition Module.h:34
nncc::core::ADT::feature::Shape as_feature_shape(const nncc::core::ADT::tensor::Shape &)
Definition caffe.cpp:54
Core coco entity for constant weights.
Definition Data.h:31

References OpBuilder::add(), coco::DLinkedList< Child, Parent >::Head::append(), morph::caffe::as_feature_shape(), caffeimport::GraphBuilderContext::bag_ctx(), caffeimport::WeightContext::blob_count(), caffeimport::WeightContext::blob_get(), caffeimport::GraphBuilderContext::block(), coco::FeatureLayouts::BCHW::create(), coco::FeatureLayouts::BC::create(), caffeimport::GraphBuilderContext::data(), InstrBuilder::eval(), coco::Block::instr(), instr_builder(), OpBuilder::load(), OpBuilder::mul(), op_builder(), OpBuilder::pop(), caffeimport::GraphBuilderContext::shape_ctx(), and caffeimport::GraphBuilderContext::weight_ctx().


The documentation for this class was generated from the following files:
- Scale.h
- Scale.cpp