// Pad: copies an input tensor into a larger output buffer, filling the
// border on each padded dimension with a single constant value.
//
// Parameters visible in this view:
//   padding_data        - read as pad_rank consecutive (before, after) int32
//                         pairs: padding_data[2*n] and padding_data[2*n + 1].
//   pad_rank            - number of (before, after) pairs read from
//                         padding_data.
//   input_shape         - shape of the input; queried through Dims(i).
//   constant_value_data - optional pointer to the fill value; when null the
//                         fill value is 0.
//
// The code below contains four copy paths that use padding_list[0],
// padding_list[0..1], padding_list[0..2] and padding_list[0..3]
// respectively, and ends with a std::runtime_error
// ("Padding for rank > 4 NYI").
//
// NOTE(review): input_data, output_data, output_shape, out_row_size and the
// out_*_offset variables are used below but declared outside the lines
// shown here; the dispatch that selects a copy path is likewise not visible
// (presumably keyed on pad_rank - confirm).
// NOTE(review): rows are copied with memcpy, so T presumably must be
// trivially copyable - confirm against the callers.
// NOTE(review): all offsets are computed in 32-bit arithmetic; very large
// tensors could overflow - confirm the expected size range.
30inline void Pad(
const int32_t *padding_data, int32_t pad_rank,
const Shape &input_shape,
32 const T *constant_value_data)
// One (before, after) padding pair per dimension.
38 using PaddingInfo = std::pair<int32_t, int32_t>;
40 using PaddingList = std::vector<PaddingInfo>;
// A null constant_value_data selects 0 as the fill value.
42 const T constant_value = constant_value_data ? *constant_value_data : 0;
// Unpack the flat padding_data array into pad_rank (before, after) pairs.
45 PaddingList padding_list(pad_rank);
46 for (int32_t n = 0; n < pad_rank; ++n)
48 const int32_t *from = padding_data + (n * 2);
49 padding_list[n] = {from[0], from[1]};
// Per dimension: input extent plus both paddings.
// NOTE(review): the statement consuming this expression is not visible
// here; presumably it is checked against the output extent - confirm.
51 for (int32_t i = 0; i < pad_rank; ++i)
54 input_shape.
Dims(i) + padding_list[i].first + padding_list[i].second);
// ---- Path using padding_list[0] only (one padded dimension) ----
// Layout written: [pad_before fill][in_row_len input elements][pad_after fill]
66 const int32_t in_row_len = input_shape.
Dims(0);
67 [[maybe_unused]]
auto [pad_before, pad_after] = padding_list[0];
68 std::fill_n(output_data, pad_before, constant_value);
69 std::memcpy(output_data + pad_before, input_data, in_row_len *
sizeof(T));
70 std::fill_n(output_data + pad_before + in_row_len, pad_after, constant_value);
// ---- Path using padding_list[0..1] (two padded dimensions) ----
// in_row_len is the innermost input extent (Dims(1)); out_row_size is used
// as the stride between consecutive output rows.
75 const int32_t in_row_len = input_shape.
Dims(1);
78 auto [pad_top, pad_bottom] = padding_list[0];
79 auto [pad_left, pad_right] = padding_list[1];
// Fill the pad_top leading output rows.
82 std::fill_n(output_data, pad_top * out_row_size, constant_value);
// Output row index one past the last row that receives input data.
84 const auto r_h_inp_lim = input_shape.
Dims(0) + pad_top;
// i walks output rows, j walks the matching input rows.
85 for (
auto i = pad_top, j = 0; i < r_h_inp_lim; ++i, ++j)
87 auto out_offset = i * out_row_size;
88 const auto in_offset = j * in_row_len;
// Left border of the row.
91 std::fill_n(output_data + out_offset, pad_left, constant_value);
92 out_offset += pad_left;
// Row payload. NOTE(review): unqualified memcpy here, std::memcpy above.
95 memcpy(output_data + out_offset, input_data + in_offset, in_row_len *
sizeof(T));
96 out_offset += in_row_len;
// Right border of the row.
99 std::fill_n(output_data + out_offset, pad_right, constant_value);
// Fill the pad_bottom trailing output rows.
103 std::fill_n(output_data + r_h_inp_lim * out_row_size, pad_bottom * out_row_size,
// ---- Path using padding_list[0..2] (three padded dimensions) ----
// in_row_len is the innermost input extent (Dims(2)).
109 const int32_t in_row_len = input_shape.
Dims(2);
// Number of output elements in one slice along dimension 0.
111 const auto plain_size = out_row_size *
output_shape.Dims(1);
// Padding pairs for dimensions 0, 1 and 2 respectively.
113 auto [pad_batches_before, pad_batches_after] = padding_list[0];
114 auto [pad_parallelepipes_before, pad_parallelepipes_after] = padding_list[1];
115 auto [pad_plains_before, pad_plains_after] = padding_list[2];
// Fill the leading slices along dimension 0.
118 std::fill_n(output_data, pad_batches_before * plain_size, constant_value);
// Output index along dimension 0, one past the last slice holding input data.
120 const auto r_h_inp_lim = input_shape.
Dims(0) + pad_batches_before;
// i walks output slices, i_inp walks the matching input slices.
121 for (
auto i = pad_batches_before, i_inp = 0; i < r_h_inp_lim; ++i, ++i_inp)
// Fill the leading rows of this slice (dimension-1 "before" padding).
// NOTE(review): out_w_offset is computed on a line not visible here.
126 std::fill_n(output_data + out_w_offset, pad_parallelepipes_before * out_row_size,
129 const auto r_w_inp_lim = input_shape.
Dims(1) + pad_parallelepipes_before;
// j walks output rows of the slice, j_inp walks the matching input rows.
130 for (
auto j = pad_parallelepipes_before, j_inp = 0; j < r_w_inp_lim; ++j, ++j_inp)
// Flat input offset of row (i_inp, j_inp) in a dense row-major layout.
133 const auto in_offset = (i_inp * input_shape.
Dims(1) + j_inp) * input_shape.
Dims(2);
// Leading fill, row payload, trailing fill along the innermost dimension.
// NOTE(review): out_offset is initialised on a line not visible here.
136 std::fill_n(output_data + out_offset, pad_plains_before, constant_value);
137 out_offset += pad_plains_before;
140 memcpy(output_data + out_offset, input_data + in_offset, in_row_len *
sizeof(T));
141 out_offset += in_row_len;
144 std::fill_n(output_data + out_offset, pad_plains_after, constant_value);
// Fill the trailing rows of this slice (dimension-1 "after" padding).
148 std::fill_n(output_data + out_w_offset + r_w_inp_lim * out_row_size,
149 pad_parallelepipes_after * out_row_size, constant_value);
// Fill the trailing slices along dimension 0.
153 std::fill_n(output_data + r_h_inp_lim * plain_size, pad_batches_after * plain_size,
// ---- Path using padding_list[0..3] (four padded dimensions) ----
// Flat offset of element (n, h, w, 0) in a dense row-major 4-D shape.
159 auto get_offset = [](
const Shape &shape, int32_t n, int32_t h, int32_t w) -> int32_t {
160 return ((n * shape.
Dims(1) + h) * shape.
Dims(2) + w) * shape.
Dims(3);
// in_row_len is the innermost input extent (Dims(3)).
162 const int32_t in_row_len = input_shape.
Dims(3);
// Output element counts: plain_size spans one index of dimension 1,
// parallelepiped_size spans one index of dimension 0.
164 const auto plain_size = out_row_size *
output_shape.Dims(2);
165 const auto parallelepiped_size = plain_size *
output_shape.Dims(1);
// Padding pairs for dimensions 0, 1, 2 and 3 respectively.
167 auto [pad_batches_before, pad_batches_after] = padding_list[0];
168 auto [pad_parallelepipes_before, pad_parallelepipes_after] = padding_list[1];
169 auto [pad_plains_before, pad_plains_after] = padding_list[2];
170 auto [pad_rows_before, pad_rows_after] = padding_list[3];
// Fill the leading blocks along dimension 0.
173 std::fill_n(output_data, pad_batches_before * parallelepiped_size, constant_value);
// Output index along dimension 0, one past the last block holding input data.
175 const auto r_b_inp_lim = input_shape.
Dims(0) + pad_batches_before;
// i / i_inp: output and input indices along dimension 0.
176 for (
auto i = pad_batches_before, i_inp = 0; i < r_b_inp_lim; ++i, ++i_inp)
// Start of output block i.
178 const auto out_h_offset = get_offset(
output_shape, i, 0, 0);
// Dimension-1 "before" padding inside this block.
180 std::fill_n(output_data + out_h_offset, pad_parallelepipes_before * plain_size,
183 const auto r_h_inp_lim = input_shape.
Dims(1) + pad_parallelepipes_before;
// j / j_inp: output and input indices along dimension 1.
184 for (
auto j = pad_parallelepipes_before, j_inp = 0; j < r_h_inp_lim; ++j, ++j_inp)
// Start of output slice (i, j).
186 const auto out_w_offset = get_offset(
output_shape, i, j, 0);
// Dimension-2 "before" padding inside this slice.
189 std::fill_n(output_data + out_w_offset, pad_plains_before * out_row_size, constant_value);
191 const auto r_w_inp_lim = input_shape.
Dims(2) + pad_plains_before;
// k / k_inp: output and input indices along dimension 2.
192 for (
auto k = pad_plains_before, k_inp = 0; k < r_w_inp_lim; ++k, ++k_inp)
// Start of the input row to copy.
195 const auto in_offset = get_offset(input_shape, i_inp, j_inp, k_inp);
// Leading fill, row payload, trailing fill along the innermost dimension.
// NOTE(review): out_c_offset is initialised on a line not visible here.
198 std::fill_n(output_data + out_c_offset, pad_rows_before, constant_value);
199 out_c_offset += pad_rows_before;
202 memcpy(output_data + out_c_offset, input_data + in_offset, in_row_len *
sizeof(T));
203 out_c_offset += in_row_len;
206 std::fill_n(output_data + out_c_offset, pad_rows_after, constant_value);
// Dimension-2 "after" padding inside this slice.
210 std::fill_n(output_data + out_w_offset + r_w_inp_lim * out_row_size,
211 pad_plains_after * out_row_size, constant_value);
// Dimension-1 "after" padding inside this block.
215 std::fill_n(output_data + out_h_offset + r_h_inp_lim * plain_size,
216 pad_parallelepipes_after * plain_size, constant_value);
// Fill the trailing blocks along dimension 0.
220 std::fill_n(output_data + r_b_inp_lim * parallelepiped_size,
221 pad_batches_after * parallelepiped_size, constant_value);
// Fallback for ranks without a copy path: reports "not yet implemented".
// NOTE(review): the condition guarding this throw is not visible here.
226 throw std::runtime_error(
"Padding for rank > 4 NYI");