30inline void Pad(
const int32_t *padding_data, int32_t pad_rank,
const Shape &input_shape,
32 const T *constant_value_data)
38 using PaddingInfo = std::pair<int32_t, int32_t>;
40 using PaddingList = std::vector<PaddingInfo>;
42 const T constant_value = constant_value_data ? *constant_value_data : 0;
45 PaddingList padding_list(pad_rank);
46 for (int32_t n = 0; n < pad_rank; ++n)
48 const int32_t *from = padding_data + (n * 2);
49 padding_list[n] = {from[0], from[1]};
51 for (int32_t i = 0; i < pad_rank; ++i)
54 input_shape.
Dims(i) + padding_list[i].first + padding_list[i].second);
66 const int32_t in_row_len = input_shape.
Dims(0);
67 std::fill_n(output_data, padding_list[0].first, constant_value);
68 std::memcpy(output_data + padding_list[0].first, input_data, in_row_len *
sizeof(T));
69 std::fill_n(output_data + padding_list[0].first + in_row_len, padding_list[0].second,
75 const int32_t in_row_len = input_shape.
Dims(1);
79 std::fill_n(output_data, padding_list[0].first * out_row_size, constant_value);
81 const auto r_h_inp_lim = input_shape.
Dims(0) + padding_list[0].first;
82 for (
auto i = padding_list[0].first, j = 0; i < r_h_inp_lim; ++i, ++j)
84 auto out_offset = i * out_row_size;
85 const auto in_offset = j * in_row_len;
88 std::fill_n(output_data + out_offset, padding_list[1].first, constant_value);
90 out_offset += padding_list[1].first;
93 memcpy(output_data + out_offset, input_data + in_offset, in_row_len *
sizeof(T));
95 out_offset += in_row_len;
98 std::fill_n(output_data + out_offset, padding_list[1].second, constant_value);
102 std::fill_n(output_data + r_h_inp_lim * out_row_size, padding_list[0].second * out_row_size,
108 const int32_t in_row_len = input_shape.
Dims(2);
110 const auto plain_size = out_row_size *
output_shape.Dims(1);
113 std::fill_n(output_data, padding_list[0].first * plain_size, constant_value);
115 const auto r_h_inp_lim = input_shape.
Dims(0) + padding_list[0].first;
116 for (
auto i = padding_list[0].first, i_inp = 0; i < r_h_inp_lim; ++i, ++i_inp)
121 std::fill_n(output_data + out_w_offset, padding_list[1].first * out_row_size,
124 const auto r_w_inp_lim = input_shape.
Dims(1) + padding_list[1].first;
125 for (
auto j = padding_list[1].first, j_inp = 0; j < r_w_inp_lim; ++j, ++j_inp)
128 const auto in_offset = (i_inp * input_shape.
Dims(1) + j_inp) * input_shape.
Dims(2);
131 std::fill_n(output_data + out_offset, padding_list[2].first, constant_value);
133 out_offset += padding_list[2].first;
136 memcpy(output_data + out_offset, input_data + in_offset, in_row_len *
sizeof(T));
138 out_offset += in_row_len;
141 std::fill_n(output_data + out_offset, padding_list[2].second, constant_value);
145 std::fill_n(output_data + out_w_offset + r_w_inp_lim * out_row_size,
146 padding_list[1].second * out_row_size, constant_value);
150 std::fill_n(output_data + r_h_inp_lim * plain_size, padding_list[0].second * plain_size,
156 auto get_offset = [](
const Shape &shape, int32_t n, int32_t h, int32_t w) -> int32_t {
157 return ((n * shape.
Dims(1) + h) * shape.
Dims(2) + w) * shape.
Dims(3);
159 const int32_t in_row_len = input_shape.
Dims(3);
161 const auto plain_size = out_row_size *
output_shape.Dims(2);
162 const auto parallelepiped_size = plain_size *
output_shape.Dims(1);
165 std::fill_n(output_data, padding_list[0].first * parallelepiped_size, constant_value);
167 const auto r_b_inp_lim = input_shape.
Dims(0) + padding_list[0].first;
168 for (
auto i = padding_list[0].first, i_inp = 0; i < r_b_inp_lim; ++i, ++i_inp)
170 const auto out_h_offset = get_offset(
output_shape, i, 0, 0);
172 std::fill_n(output_data + out_h_offset, padding_list[1].first * plain_size, constant_value);
174 const auto r_h_inp_lim = input_shape.
Dims(1) + padding_list[1].first;
175 for (
auto j = padding_list[1].first, j_inp = 0; j < r_h_inp_lim; ++j, ++j_inp)
177 const auto out_w_offset = get_offset(
output_shape, i, j, 0);
180 std::fill_n(output_data + out_w_offset, padding_list[2].first * out_row_size,
183 const auto r_w_inp_lim = input_shape.
Dims(2) + padding_list[2].first;
184 for (
auto k = padding_list[2].first, k_inp = 0; k < r_w_inp_lim; ++k, ++k_inp)
187 const auto in_offset = get_offset(input_shape, i_inp, j_inp, k_inp);
190 std::fill_n(output_data + out_c_offset, padding_list[3].first, constant_value);
192 out_c_offset += padding_list[3].first;
195 memcpy(output_data + out_c_offset, input_data + in_offset, in_row_len *
sizeof(T));
197 out_c_offset += in_row_len;
200 std::fill_n(output_data + out_c_offset, padding_list[3].second, constant_value);
204 std::fill_n(output_data + out_w_offset + r_w_inp_lim * out_row_size,
205 padding_list[2].second * out_row_size, constant_value);
209 std::fill_n(output_data + out_h_offset + r_h_inp_lim * plain_size,
210 padding_list[1].second * plain_size, constant_value);
213 std::fill_n(output_data + r_b_inp_lim * parallelepiped_size,
214 padding_list[0].second * parallelepiped_size, constant_value);
218 throw std::runtime_error(
"Padding for rank > 4 NYI");