ONE - On-device Neural Engine
Loading...
Searching...
No Matches
conv2d.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <nest/Module.h>
18
19int main(int, char **)
20{
21 // This example shows how to specify convolution with IFM(1x3x3) and Kernel(1x1x3x3) with nest
22 // - STRIDE is 1, and there is no padding
23 //
24 // The below code corresponds to the following nest DSL code:
25 // ----------------------------------------------------------------------------------------------
26 // Domain ofm(1, 1, 1)
27 // Domain ifm(1, 3, 3)
28 // Domain ker(1, 1, 3, 3)
29 //
30 // Var ofm_ch : { min = 0, max = 1 }
31 // Var ofm_row : { min = 0, max = 1 }
32 // Var ofm_col : { min = 0, max = 1 }
33 // Var ker_ch : { min = 0, max = 1 }
34 // Var ker_row : { min = 0, max = 3 }
35 // Var ker_col : { min = 0, max = 3 }
36 //
37 // PUSH ifm(ker_ch, ker_row, ker_col) * ker(ofm_ch, ker_ch, ofm_row + ker_row, ofm_col + ker_col)
38 // RET ofm(ofm_ch, ofm_row, ofm_col)
39 // ----------------------------------------------------------------------------------------------
40 //
41 // The first part declares Domain(s) which corresponds to a multi-dimensional array in C-style
42 // (without type). For example, 'Domain ofm(1, 3, 3)' corresponds to the
43 // following C array declaration.
44 // float ofm[1][3][3];
45 // (Here we assume that domain type is 'float')
46 //
47 // The second part declares Var(s) which serves as a loop iteration variable. Basically, each
48 // variable emits one for loop and these loops are nested. As there are 6 variables in the above
49 // example, there will be 6 nested-loops.
50 //
51 // Each variable has a corresponding bound, and the bound of each variable states the starting /
52 // termination condition. For example, 'Var ofm_ch : { min = 0, max = 1 }' will introduce the
53 // following for loop:
54 // ----------------------------------------------------------------------------------------------
55 // for (int ofm_ch = 0; ofm_ch < 1; ++ofm_ch) { ... }
56 // ----------------------------------------------------------------------------------------------
57 //
58 // The last part declares statement(s) which state the computation performed inside these nested
59 // loops. Nest is stack-based. There is a virtual stack inside nested loop, and the evaluation of
60 // each statement will update this stack.
61 //
62 // Each nest code has one return statement (RET). This return statement specifies where to write
63 // the computed result.
64 //
65 // PUSH 'expr' statement evaluates an arithmetic expression (specified by 'expr') and pushes the
66 // numeric result to the stack. When PUSH statement evaluates an arithmetic expression, variables
67 // that do not appear in RET statement are treated as reduction variables. For example,
68 // ker_ch, ker_row, and ker_col do not appear in RET statement. So, PUSH '...' statement in the
69 // above example corresponds to the following nested loops:
70 // ----------------------------------------------------------------------------------------------
71 // float value = 0.0f;
72 //
73 // for (int ker_ch = 0; ker_ch < 1; ++ker_ch) {
74 // for (int ker_row = 0; ker_row < 3; ++ker_row) {
75 // for (int ker_col = 0; ker_col < 3; ++ker_col) {
76 // float ifm_value = ifm[ker_ch][ofm_row + ker_row][ofm_col + ker_col];
77 // float ker_value = ker[ofm_ch][ker_ch][ker_row][ker_col];
78 // value += ifm_value * ker_value;
79 // }
80 // }
81 // }
82 // ----------------------------------------------------------------------------------------------
83 //
84 // In summary, the above nest example corresponds to the following 2D convolution:
85 // ----------------------------------------------------------------------------------------------
86 // float ofm[1][1][1];
87 // float ifm[1][3][3];
88 // float ker[1][1][3][3];
89 //
90 // for (int ofm_ch = 0; ofm_ch < 1; ++ofm_ch) {
91 // for (int ofm_row = 0; ofm_row < 1; ++ofm_row) {
92 // for (int ofm_col = 0; ofm_col < 1; ++ofm_col) {
93 // float value = 0.0f;
94 //
95 // for (int ker_ch = 0; ker_ch < 1; ++ker_ch) {
96 // for (int ker_row = 0; ker_row < 3; ++ker_row) {
97 // for (int ker_col = 0; ker_col < 3; ++ker_col) {
98 // float ifm_value = ifm[ker_ch][ofm_row + ker_row][ofm_col + ker_col];
99 // float ker_value = ker[ofm_ch][ker_ch][ker_row][ker_col];
100 // value += ifm_value * ker_value;
101 // }
102 // }
103 // }
104 //
105 // ofm[ofm_ch][ofm_col][ofm_row] = value;
106 // }
107 // }
108 // }
109 // ----------------------------------------------------------------------------------------------
110 //
112
113 //
114 // Domains
115 //
116 auto ofm = m.domain().make({1 /*C*/, 1 /*H*/, 1 /*W*/});
117 auto ifm = m.domain().make({1 /*C*/, 3 /*H*/, 3 /*W*/});
118 auto ker = m.domain().make({1 /*N*/, 1 /*C*/, 3 /*H*/, 3 /*W*/});
119
120 //
121 // Variables
122 //
123 auto ofm_ch = m.var().make();
124 auto ofm_row = m.var().make();
125 auto ofm_col = m.var().make();
126
127 auto ker_ch = m.var().make();
128 auto ker_row = m.var().make();
129 auto ker_col = m.var().make();
130
131 // Declare the bound of each variables
132 using nest::Bound;
133
134 m.var().bound(ofm_ch) = Bound{0, 1};
135 m.var().bound(ofm_row) = Bound{0, 1};
136 m.var().bound(ofm_col) = Bound{0, 1};
137
138 m.var().bound(ker_ch) = Bound{0, 1};
139 m.var().bound(ker_row) = Bound{0, 3};
140 m.var().bound(ker_col) = Bound{0, 3};
141
142 //
143 // Statement
144 //
145 auto ifm_value = ifm(ker_ch, ofm_row + ker_row, ofm_col + ker_col);
146 auto ker_value = ker(ofm_ch, ker_ch, ker_row, ker_col);
147
148 m.push(ifm_value * ker_value);
149 m.ret(ofm(ofm_ch, ofm_row, ofm_col));
150
151 return 0;
152}
int main(void)
Domain make(std::initializer_list< uint32_t > dims)
DomainContext & domain(void)
Definition Module.h:44