ONE - On-device Neural Engine
Loading...
Searching...
No Matches
DepthwiseConv2D.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "NodeExecution.h"
19
20#include "NodeDataImpl.h"
21#include "NodeDomain.h"
22#include "Validation.h"
23
29
30#include <cassert>
31#include <stdexcept>
32
33namespace
34{
35
/**
 * @brief Compute the output extent of a strided window along one spatial axis
 *
 * @param image_size   Input extent along the axis
 * @param whole_pad    Total padding along the axis (both sides summed)
 * @param filter_size  Window (kernel) extent along the axis
 * @param stride       Step between consecutive window positions; must be non-zero
 * @return (image_size + whole_pad - filter_size) / stride + 1
 *
 * @note All preconditions are checked with assert only, so they are not enforced
 *       when NDEBUG is defined.
 */
inline uint32_t compute_out_size(uint32_t image_size, uint32_t whole_pad, uint32_t filter_size,
                                 uint32_t stride)
{
  // A zero stride would divide by zero below.
  assert(stride != 0);
  // The arithmetic is unsigned: a filter larger than the padded input would wrap
  // around to a huge value instead of going negative.
  assert(image_size + whole_pad >= filter_size);

  const uint32_t span = image_size + whole_pad - filter_size;

  // The window positions have to tile the padded input exactly.
  assert(span % stride == 0);
  return span / stride + 1;
}
47
54
60template <typename RET_T, typename IFM_T, typename KER_T>
61Buffer<RET_T> calc_dw_conv2d(const loco::DepthwiseConv2D *dw_conv2d, const Buffer<IFM_T> *ifm_buf,
62 const Buffer<KER_T> *ker_buf)
63{
64 auto ifm_shape = ifm_buf->shape();
65 auto ker_shape = ker_buf->shape();
66
67 locomotiv::validate(ifm_shape.rank() == 4, "ifm rank must be 4");
68 locomotiv::validate(ker_shape.rank() == 4, "depthwise filter rank must be 4");
69 locomotiv::validate(ifm_shape.dim(3 /* of NHWC */) == ker_shape.dim(2 /* of HWCM */),
70 "channel value mismatch"); // should have same channel values
71
72 const uint32_t ifm_height = ifm_shape.dim(1);
73 const uint32_t ifm_width = ifm_shape.dim(2);
74
75 const uint32_t ker_height = ker_shape.dim(0);
76 const uint32_t ker_width = ker_shape.dim(1);
77
78 const uint32_t stride_width = dw_conv2d->stride()->horizontal();
79 const uint32_t stride_height = dw_conv2d->stride()->vertical();
80
81 // TODO Enable dilations. Let's set these to 1 for now.
82 const uint32_t dilation_width_factor = 1;
83 const uint32_t dilation_height_factor = 1;
84
85 const uint32_t pad_top = dw_conv2d->pad()->top();
86 const uint32_t pad_bottom = dw_conv2d->pad()->bottom();
87
88 const uint32_t pad_left = dw_conv2d->pad()->left();
89 const uint32_t pad_right = dw_conv2d->pad()->right();
90
91 const uint32_t ofm_height =
92 compute_out_size(ifm_height, pad_top + pad_bottom, ker_height, stride_height);
93 const uint32_t ofm_width =
94 compute_out_size(ifm_width, pad_left + pad_right, ker_width, stride_width);
95
96 const uint32_t batches = ifm_shape.dim(0);
97 const uint32_t ifm_depth = ifm_shape.dim(3);
98 const uint32_t multiplier = ker_shape.dim(3);
99 const uint32_t ofm_depth = ifm_depth * multiplier;
100
101 Shape ofm_shape{batches, ofm_height, ofm_width, ofm_depth};
102 auto ofm_buf = make_buffer<RET_T, LexicalLayout>(ofm_shape);
103
104 for (uint32_t batch = 0; batch < batches; ++batch)
105 {
106 for (uint32_t ofm_y = 0; ofm_y < ofm_height; ++ofm_y)
107 {
108 for (uint32_t ofm_x = 0; ofm_x < ofm_width; ++ofm_x)
109 {
110 for (uint32_t ch = 0; ch < ifm_depth; ++ch)
111 {
112 for (uint32_t nth = 0; nth < multiplier; nth++)
113 {
114 const int in_x_origin = (ofm_x * stride_width) - pad_left;
115 const int in_y_origin = (ofm_y * stride_height) - pad_top;
116 float total = 0.f;
117 for (uint32_t ker_y = 0; ker_y < ker_height; ++ker_y)
118 {
119 for (uint32_t ker_x = 0; ker_x < ker_width; ++ker_x)
120 {
121 const int in_x = in_x_origin + dilation_width_factor * ker_x;
122 const int in_y = in_y_origin + dilation_height_factor * ker_y;
123 // If the location is outside the bounds of the input image,
124 // use zero as a default value.
125 if ((in_x >= 0) && ((unsigned)in_x < ifm_width) && (in_y >= 0) &&
126 ((unsigned)in_y < ifm_height))
127 {
128 auto ifm_value = ifm_buf->at(Index({batch, (unsigned)in_y, (unsigned)in_x, ch}));
129 auto ker_value = ker_buf->at(Index({ker_y, ker_x, ch, nth}));
130 total += (ifm_value * ker_value);
131 }
132 }
133 }
134 uint32_t ofm_channel = ch * multiplier + nth;
135 ofm_buf.at(Index({batch, ofm_y, ofm_x, ofm_channel})) = total;
136 }
137 }
138 }
139 }
140 }
141 return ofm_buf;
142}
143
144} // namespace
145
146namespace
147{
148
149using namespace locomotiv;
150
// Execute a DepthwiseConv2D node: read the data annotated on its ifm and ker inputs,
// validate it, run calc_dw_conv2d for the FLOAT32/FLOAT32 case and annotate the
// result on the node itself.
151void execute_node(loco::DepthwiseConv2D *dw_conv2d)
152{
153 auto ifm_data = annot_data(dw_conv2d->ifm());
154 auto ker_data = annot_data(dw_conv2d->ker());
155
// Both inputs must already carry data and must be rank-4.
156 validate(ifm_data, "Can't find input data of DepthwiseConv2D");
157 validate(ifm_data->shape()->rank() == 4, "ifm rank must be 4");
158
159 validate(ker_data, "Can't find kernel data of DepthwiseConv2D");
160 validate(ker_data->shape()->rank() == 4, "Kernel rank must be 4");
161
// NOTE(review): listing lines 162 and 164 are missing here, so only the message
// arguments of two calls are visible. Presumably they are validate(...) calls that
// check the domains of ifm and ker -- confirm against the original file.
163 "IFM of DepthwiseConv2D is not feature");
165 "Kernel of DepthwiseConv2D is not depthwise filter");
166
167 std::unique_ptr<NodeData> dw_conv2d_result = nullptr;
168
// Only the all-FLOAT32 combination is implemented; any other pair of dtypes throws.
169 if (ifm_data->dtype() == loco::DataType::FLOAT32 && ker_data->dtype() == loco::DataType::FLOAT32)
170 {
171 auto ifm_buf = ifm_data->as_f32_bufptr();
172 auto ker_buf = ker_data->as_f32_bufptr();
173
174 auto dw_conv2d_buf = calc_dw_conv2d<float, float, float>(dw_conv2d, ifm_buf, ker_buf);
175
176 dw_conv2d_result = make_data(dw_conv2d_buf);
177 }
178 else
179 throw std::runtime_error("NYI for these DataTypes");
180
181 assert(dw_conv2d_result != nullptr);
182
// Hand ownership of the computed data to the node's annotation.
183 annot_data(dw_conv2d, std::move(dw_conv2d_result));
// NOTE(review): listing line 184 is missing here; presumably it annotates the
// node's domain -- confirm against the original file.
185}
186
187} // namespace
188
189namespace locomotiv
190{
191
192void NodeExecution::execute(loco::DepthwiseConv2D *dw_conv2d) { execute_node(dw_conv2d); }
193
194} // namespace locomotiv
Depthwise 2D Convolution.
Definition Nodes.h:582
Node * ifm(void) const
Definition Nodes.h:584
const Stride< 2 > * stride(void) const
Definition Nodes.h:595
const Padding2D * pad(void) const
Definition Nodes.h:591
Node * ker(void) const
Definition Nodes.h:587
uint32_t left(void) const
Definition Padding2D.h:49
uint32_t top(void) const
Definition Padding2D.h:41
uint32_t bottom(void) const
Definition Padding2D.h:45
uint32_t right(void) const
Definition Padding2D.h:53
uint32_t horizontal(void) const
Definition Stride.h:40
uint32_t vertical(void) const
Definition Stride.h:36
bool validate(Code *code)
void validate(bool true_cond, const std::string &&exception_msg)
Definition Validation.h:26
void annot_domain(loco::Node *node, const loco::Domain &domain)
Wrapper to annotate domain to node. Cannot annotate unknown domain.
std::unique_ptr< NodeData > make_data(const NodeData::Buffer< DT > &buffer)
Copy buffer to make NodeData.
Buffer< T > make_buffer(const Shape &shape)
Definition Buffer.h:47
Definition Shape.h:28