ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Conv2D.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "NodeExecution.h"
19
20#include "NodeDataImpl.h"
21#include "NodeDomain.h"
22#include "Validation.h"
23
29
30#include <cassert>
31#include <stdexcept>
32
33namespace
34{
// Returns the number of output positions along one spatial axis of a convolution.
//
// @param image_size  input extent along the axis, INCLUDING the padding on both sides
// @param filter_size filter extent along the axis (dilation is not applied here)
// @param stride      step between two consecutive filter positions
// @return (image_size - filter_size) / stride + 1
//
// The debug assertion requires the stride to tile the padded input exactly
// (no leftover rows/columns).
// NOTE(review): that is stricter than e.g. TensorFlow's VALID padding; confirm it
// matches the shape inference used for loco::Conv2D.
//
// @throws std::runtime_error if stride is 0 or filter_size exceeds image_size.
//         Either case would otherwise divide by zero or wrap the unsigned subtraction
//         into a huge bogus output size.
inline uint32_t compute_out_size(uint32_t image_size, uint32_t filter_size, uint32_t stride)
{
  if (stride == 0)
    throw std::runtime_error("Conv2D stride must be positive");
  if (filter_size > image_size)
    throw std::runtime_error("Conv2D filter is larger than the padded input");

  // Equivalent to the former (image + stride - filter) form, but cannot overflow.
  assert((image_size - filter_size) % stride == 0);
  return (image_size - filter_size) / stride + 1;
}
41
48
53template <typename RET_T, typename IFM_T, typename FIL_T>
54Buffer<RET_T> calc_conv2D(const loco::Conv2D *conv2d, const Buffer<IFM_T> *input_buf,
55 const Buffer<FIL_T> *filter_buf)
56{
57 auto input_shape = input_buf->shape();
58 auto filter_shape = filter_buf->shape();
59
60 locomotiv::validate(input_shape.rank() == 4, "ifm rank must be 4");
61 locomotiv::validate(filter_shape.rank() == 4, "filter rank must be 4");
62 locomotiv::validate(input_shape.dim(3) == filter_shape.dim(3),
63 "channel value mismatch"); // should have same channel values
64
65 const uint32_t input_height = input_shape.dim(1);
66 const uint32_t input_width = input_shape.dim(2);
67
68 const uint32_t filter_height = filter_shape.dim(1);
69 const uint32_t filter_width = filter_shape.dim(2);
70
71 const uint32_t stride_width = conv2d->stride()->horizontal();
72 const uint32_t stride_height = conv2d->stride()->vertical();
73
74 // TODO Enable dilations. Let's set these to 1 for now.
75 const uint32_t dilation_width_factor = 1;
76 const uint32_t dilation_height_factor = 1;
77
78 const uint32_t pad_top = conv2d->pad()->top();
79 const uint32_t pad_bottom = conv2d->pad()->bottom();
80
81 const uint32_t pad_left = conv2d->pad()->left();
82 const uint32_t pad_right = conv2d->pad()->right();
83
84 const uint32_t output_height =
85 compute_out_size(input_height + pad_top + pad_bottom, filter_height, stride_height);
86 const uint32_t output_width =
87 compute_out_size(input_width + pad_left + pad_right, filter_width, stride_width);
88
89 const uint32_t batches = input_shape.dim(0);
90 const uint32_t input_depth = input_shape.dim(3);
91 const uint32_t output_depth = filter_shape.dim(0);
92
93 Shape output_shape{batches, output_height, output_width, output_depth};
94 auto output_buf = make_buffer<RET_T, LexicalLayout>(output_shape);
95
96 for (uint32_t batch = 0; batch < batches; ++batch)
97 {
98 for (uint32_t out_y = 0; out_y < output_height; ++out_y)
99 {
100 for (uint32_t out_x = 0; out_x < output_width; ++out_x)
101 {
102 for (uint32_t out_channel = 0; out_channel < output_depth; ++out_channel)
103 {
104 const int in_x_origin = (out_x * stride_width) - pad_left;
105 const int in_y_origin = (out_y * stride_height) - pad_top;
106
107 RET_T total = static_cast<RET_T>(0);
108
109 for (uint32_t filter_y = 0; filter_y < filter_height; ++filter_y)
110 {
111 for (uint32_t filter_x = 0; filter_x < filter_width; ++filter_x)
112 {
113 for (uint32_t in_channel = 0; in_channel < input_depth; ++in_channel)
114 {
115 const int32_t in_x = in_x_origin + dilation_width_factor * filter_x;
116 const int32_t in_y = in_y_origin + dilation_height_factor * filter_y;
117
118 // If the location is outside the bounds of the input image,
119 // use zero as a default value.
120 if ((in_x >= 0) && ((unsigned)in_x < input_width) && (in_y >= 0) &&
121 ((unsigned)in_y < input_height))
122 {
123 auto input_value =
124 input_buf->at(Index({batch, (unsigned)in_y, (unsigned)in_x, in_channel}));
125 auto filter_value =
126 filter_buf->at(Index({out_channel, filter_y, filter_x, in_channel}));
127 total += (input_value * filter_value);
128 }
129 }
130 }
131 }
132 output_buf.at(Index({batch, out_y, out_x, out_channel})) = total;
133 }
134 }
135 }
136 }
137 return output_buf;
138}
139
140} // namespace
141
142namespace
143{
144
145using namespace locomotiv;
146
147void execute_node(loco::Conv2D *conv2d)
148{
149 auto ifm_data = annot_data(conv2d->ifm());
150 auto ker_data = annot_data(conv2d->ker());
151
152 validate(ifm_data, "Can't find input data of Conv2D");
153 validate(ifm_data->shape()->rank() == 4, "ifm rank must be 4");
154
155 validate(ker_data, "Can't find kernel data of Conv2D");
156 validate(ker_data->shape()->rank() == 4, "Kernel rank must be 4");
157
158 validate(annot_domain(conv2d->ifm()) == loco::Domain::Feature, "IFM of Conv2D is not feature");
159 validate(annot_domain(conv2d->ker()) == loco::Domain::Filter, "Kernel of Conv2D is not filter");
160
161 std::unique_ptr<NodeData> conv2d_result = nullptr;
162
163 if (ifm_data->dtype() == loco::DataType::FLOAT32 && ker_data->dtype() == loco::DataType::FLOAT32)
164 {
165 auto ifm_buf = ifm_data->as_f32_bufptr();
166 auto ker_buf = ker_data->as_f32_bufptr();
167
168 auto conv2d_buf = calc_conv2D<float, float, float>(conv2d, ifm_buf, ker_buf);
169
170 conv2d_result = make_data(conv2d_buf);
171 }
172 else
173 throw std::runtime_error("NYI for these DataTypes");
174
175 assert(conv2d_result != nullptr);
176
177 annot_data(conv2d, std::move(conv2d_result));
179}
180
181} // namespace
182
183namespace locomotiv
184{
185
186void NodeExecution::execute(loco::Conv2D *conv2d) { execute_node(conv2d); }
187
188} // namespace locomotiv
2D Spatial Convolution
Definition Nodes.h:554
const Stride< 2 > * stride(void) const
Definition Nodes.h:567
Node * ker(void) const
Definition Nodes.h:559
const Padding2D * pad(void) const
Definition Nodes.h:563
Node * ifm(void) const
Definition Nodes.h:556
uint32_t left(void) const
Definition Padding2D.h:49
uint32_t top(void) const
Definition Padding2D.h:41
uint32_t bottom(void) const
Definition Padding2D.h:45
uint32_t right(void) const
Definition Padding2D.h:53
uint32_t horizontal(void) const
Definition Stride.h:40
uint32_t vertical(void) const
Definition Stride.h:36
const luci_interpreter::RuntimeShape output_shape
bool validate(Code *code)
void validate(bool true_cond, const std::string &&exception_msg)
Definition Validation.h:26
void annot_domain(loco::Node *node, const loco::Domain &domain)
Wrapper to annotate domain to node. Cannot annotate unknown domain.
std::unique_ptr< NodeData > make_data(const NodeData::Buffer< DT > &buffer)
Copy buffer to make NodeData.
Buffer< T > make_buffer(const Shape &shape)
Definition Buffer.h:47
Definition Shape.h:28