32 int kwidth,
int stride_width,
int stride_height,
33 int pad_width,
int pad_height,
int in_width,
int in_height,
34 int in_depth,
int single_buffer_length,
int buffer_id,
35 const T *in_data, T *conv_buffer_data, uint8_t zero_byte)
// Body statements of a patch-extraction routine: copies one kheight x kwidth x in_depth
// input patch into a slot of conv_buffer_data and fills the out-of-image part of the
// patch with zero_byte.
// NOTE(review): the leading integers look like line numbers carried over from the file
// this listing was extracted from; braces and any `else` keywords are not visible here,
// so the branch structure described below is inferred from statement order — confirm.

// Per-row strides, in elements: one patch row in the buffer, one image row in the input.
40 const int kwidth_times_indepth = kwidth * in_depth;
41 const int inwidth_times_indepth = in_width * in_depth;
// Unclamped ("ungated") vertical extent of the patch in input coordinates; the start can
// be negative and the end can exceed in_height when the patch overlaps the padding.
42 const int ih_ungated_start = h * stride_height - pad_height;
43 const int ih_ungated_end = (ih_ungated_start + kheight);
44 const int ih_end = std::min(ih_ungated_end, in_height);
// Same for the horizontal extent.
45 const int iw_ungated_start = w * stride_width - pad_width;
46 const int iw_ungated_end = (iw_ungated_start + kwidth);
47 const int iw_end = std::min(iw_ungated_end, in_width);
// h_offset / w_offset: how many patch rows / columns lie before the image origin.
50 const int h_offset = std::max(0, -ih_ungated_start);
51 const int w_offset = std::max(0, -iw_ungated_start);
// First in-bounds input row / column covered by the patch.
52 const int ih_start = std::max(0, ih_ungated_start);
53 const int iw_start = std::max(0, iw_ungated_start);
// Number of elements copied from the input for each patch row.
54 const int single_row_num = std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
// Start of this patch's slot in the buffer, then the destination of the first in-bounds
// element (skipping h_offset full rows and w_offset columns of padding).
55 const int output_row_offset = (buffer_id * single_buffer_length);
56 int out_offset = output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
// Offset() is defined outside this listing; presumably the flat index of
// (b, ih_start, iw_start, 0) in input_shape — confirm.
57 int in_offset =
Offset(input_shape, b, ih_start, iw_start, 0);
// Padding amounts on each side of the patch, in rows / columns.
60 const int top_padding = h_offset;
61 const int bottom_padding = (ih_ungated_end - ih_end);
62 const int left_padding = w_offset;
63 const int right_padding = (iw_ungated_end - iw_end);
// Sanity check: copied width plus left/right padding must add up to the kernel width.
64 assert(single_row_num == ((kwidth - (left_padding + right_padding)) * in_depth));
// Fill the rows above the image. memset writes zero_byte into every byte, so for
// sizeof(T) > 1 each element ends up as that byte repeated, not as the value zero_byte.
70 const int top_row_elements = (top_padding * kwidth * in_depth);
71 memset(conv_buffer_data + output_row_offset, zero_byte, (top_row_elements *
sizeof(T)));
// No horizontal padding: each in-bounds row is one contiguous copy.
76 if ((left_padding == 0) && (right_padding == 0))
78 for (
int ih = ih_start; ih < ih_end; ++ih)
80 memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num *
sizeof(T));
81 out_offset += kwidth_times_indepth;
82 in_offset += inwidth_times_indepth;
// Presumably the `else` arm of the test above (keyword not visible): per row, fill the
// left padding, copy the in-bounds span, then fill the right padding if there is any.
87 for (
int ih = ih_start; ih < ih_end; ++ih)
// The left padding sits immediately before out_offset within the same patch row.
91 const int left_start = (out_offset - (left_padding * in_depth));
92 memset(conv_buffer_data + left_start, zero_byte, (left_padding * in_depth *
sizeof(T)));
94 memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num *
sizeof(T));
95 if (right_padding > 0)
97 const int right_start = (out_offset + single_row_num);
98 memset(conv_buffer_data + right_start, zero_byte, (right_padding * in_depth *
sizeof(T)));
100 out_offset += kwidth_times_indepth;
101 in_offset += inwidth_times_indepth;
// Fill the rows below the image; they start after the top padding and the copied rows.
107 if (bottom_padding > 0)
109 const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
110 const int bottom_start =
111 output_row_offset + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
112 memset(conv_buffer_data + bottom_start, zero_byte, (bottom_row_elements *
sizeof(T)));
120 const int32_t *zero_bytes,
const int zero_bytes_len)
// Body statements of a dilated im2col routine: for every output position it writes one
// row of im2col_data holding the filter_height x filter_width x input_depth input values
// sampled with the dilation factors, using zero_byte where a tap falls outside the input.
// NOTE(review): the leading integers look like line numbers carried over from the file
// this listing was extracted from; braces, `else` keywords and several declarations
// (batches, output_height, output_width, strides, pads, dilation factors) are not visible
// here, so the nesting described below is inferred from statement order — confirm.

// This path requires at least one dilation factor to differ from 1.
135 assert(dilation_width_factor != 1 || dilation_height_factor != 1);
// Dims 1 and 2 are read as height and width, dim 3 as depth.
138 const int input_height = input_shape.
Dims(1);
139 const int input_width = input_shape.
Dims(2);
// MatchingDim is defined outside this listing; presumably it checks that dim 3 of the
// input and of the filter agree and returns it — confirm.
140 const int input_depth =
MatchingDim(input_shape, 3, filter_shape, 3);
141 const int filter_height = filter_shape.
Dims(1);
142 const int filter_width = filter_shape.
Dims(2);
// row_shape indexes output positions (batch, out_y, out_x); col_shape indexes positions
// inside one patch (filter_y, filter_x, channel); im2col_shape is the resulting 2-D
// layout of FlatSize(row_shape) rows by FlatSize(col_shape) columns.
149 const Shape row_shape({1, batches, output_height, output_width});
151 const Shape col_shape({1, filter_height, filter_width, input_depth});
153 const Shape im2col_shape({1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
156 for (
int batch = 0; batch < batches; ++batch)
// Fill value for this batch: zero_bytes[batch] when more than one value was supplied,
// otherwise zero_bytes[0]. The left-hand side of this expression is not visible;
// presumably it initializes the zero_byte used by the memset calls below — confirm.
159 zero_bytes_len > 1 ?
static_cast<T
>(zero_bytes[batch]) :
static_cast<T
>(zero_bytes[0]);
160 for (
int out_y = 0; out_y < output_height; ++out_y)
162 for (
int out_x = 0; out_x < output_width; ++out_x)
// Row of the im2col matrix that belongs to this output position.
166 int row_offset =
Offset(row_shape, 0, batch, out_y, out_x);
// Input coordinates of the patch's top-left tap; negative values fall in the padding.
167 const int in_x_origin = (out_x * stride_width) - pad_width;
168 const int in_y_origin = (out_y * stride_height) - pad_height;
170 for (
int filter_y = 0; filter_y < filter_height; ++filter_y)
172 const int in_y = in_y_origin + dilation_height_factor * filter_y;
// Filter row lies inside the input: handle each tap of the row individually.
173 if ((in_y >= 0) && (in_y < input_height))
177 for (
int filter_x = 0; filter_x < filter_width; ++filter_x)
179 const int in_x = in_x_origin + dilation_width_factor * filter_x;
180 int col_offset =
Offset(col_shape, 0, filter_y, filter_x, 0);
181 T *dst = im2col_data +
Offset(im2col_shape, 0, 0, row_offset, col_offset);
// In-bounds tap: copy all input_depth channels of that input pixel.
182 if ((in_x >= 0) && (in_x < input_width))
185 T
const *src = input_data +
Offset(input_shape, batch, in_y, in_x, 0);
186 memcpy(dst, src, input_depth *
sizeof(T));
// Presumably the `else` arm (keyword not visible): out-of-bounds tap, fill its channels.
// memset writes the fill value into every byte of the span.
191 memset(dst, zero_byte, input_depth *
sizeof(T));
// Presumably the `else` arm of the in_y test (keyword not visible): the whole filter row
// is outside the input, so all filter_width * input_depth elements are filled at once.
198 int col_offset =
Offset(col_shape, 0, filter_y, 0, 0);
199 T *dst = im2col_data +
Offset(im2col_shape, 0, 0, row_offset, col_offset);
200 memset(dst, zero_byte, filter_width * input_depth *
sizeof(T));
// Body statements of a non-dilated im2col driver: reads the input dimensions, then visits
// every (b, h, w) output position and issues one call per position.
// NOTE(review): the leading integers look like line numbers carried over from the file
// this listing was extracted from; batches, output_height, output_width, output_depth and
// buffer_id are declared outside the visible lines.
231 const int input_depth = input_shape.
Dims(3);
232 const int input_width = input_shape.
Dims(2);
233 const int input_height = input_shape.
Dims(1);
240 for (
int b = 0; b < batches; ++b)
242 for (
int h = 0; h < output_height; ++h)
244 for (
int w = 0; w < output_width; ++w)
// Tail of a call's argument list; the callee name and leading arguments are not visible.
// The eleven arguments line up one-to-one with the trailing parameters of the
// ExtractPatchIntoBufferColumn signature in this listing (output_depth in the
// single_buffer_length position), so this is presumably that call — confirm.
247 stride_height, pad_width, pad_height, input_width,
248 input_height, input_depth, output_depth, buffer_id, input_data,
249 output_data, zero_byte);
// Signature only; no body is attached at this position, and no `template <typename T>`
// header is visible even though T is used.
// NOTE(review): `¶ms` is very likely a mis-encoded `&params` (an `&para` sequence decoded
// as an HTML entity) — confirm against the original file before compiling.
void Im2col(const ConvParams ¶ms, int kheight, int kwidth, uint8_t zero_byte, const Shape &input_shape, const T *input_data, const Shape &output_shape, T *output_data)
// Signature only; no body is attached at this position, and no `template <typename T>`
// header is visible even though T is used.
// w, h, b select the output column, output row and batch of the patch; in_data is the
// source and conv_buffer_data the destination, matching how the patch-extraction
// statements in this listing use them. zero_byte is the byte value handed to memset for
// padding.
void ExtractPatchIntoBufferColumn(const Shape &input_shape, int w, int h, int b, int kheight, int kwidth, int stride_width, int stride_height, int pad_width, int pad_height, int in_width, int in_height, int in_depth, int single_buffer_length, int buffer_id, const T *in_data, T *conv_buffer_data, uint8_t zero_byte)
// Signature only; no body is attached at this position, and no `template <typename T>`
// header is visible even though T is used.
// zero_bytes holds zero_bytes_len fill values; the dilated-im2col statements in this
// listing index it per batch when zero_bytes_len > 1 and use element 0 otherwise.
// NOTE(review): `¶ms` is very likely a mis-encoded `&params` (an `&para` sequence decoded
// as an HTML entity) — confirm against the original file before compiling.
void DilatedIm2col(const ConvParams ¶ms, const Shape &input_shape, const T *input_data, const Shape &filter_shape, const Shape &output_shape, T *im2col_data, const int32_t *zero_bytes, const int zero_bytes_len)