45 const int input_height = input_shape.
Dims(1);
46 const int input_width = input_shape.
Dims(2);
61 std::fill(output_data, output_data +
output_shape.FlatSize(), 0.0);
62 std::fill(arg_max_index, arg_max_index +
output_shape.FlatSize(), -1);
66 (pad_height < filter_height) ? 0 : (pad_height - filter_height) / stride_height + 1;
67 const int h_end = std::min((input_height + pad_height - 1) / stride_height + 1, output_height);
70 (pad_width < filter_width) ? 0 : (pad_width - filter_width) / stride_width + 1;
71 const int w_end = std::min((input_width + pad_width - 1) / stride_width + 1, output_width);
73 for (
int b = 0; b < batches; ++b)
75 for (
int h_idx = h_start; h_idx < h_end; h_idx++)
77 for (
int w_idx = w_start; w_idx < w_end; w_idx++)
80 out_mat.col(
offset).setConstant(std::numeric_limits<float>::lowest());
85 for (
int b = 0; b < batches; ++b)
87 for (
int h = 0; h < input_height; ++h)
89 for (
int w = 0; w < input_width; ++w)
93 int hpad = h + pad_height;
94 int wpad = w + pad_width;
96 int h_start = (hpad < filter_height) ? 0 : (hpad - filter_height) / stride_height + 1;
97 int h_end = std::min(hpad / stride_height + 1, output_height);
99 int w_start = (wpad < filter_width) ? 0 : (wpad - filter_width) / stride_width + 1;
100 int w_end = std::min(wpad / stride_width + 1, output_width);
103 for (
int ph = h_start; ph < h_end; ++ph)
105 for (
int pw = w_start; pw < w_end; ++pw)
107 const int out_offset =
NodeOffset(b, ph, pw, output_height, output_width);
108 const int in_offset =
NodeOffset(b, h, w, input_height, input_width);
110 const auto out_vector = out_mat.col(out_offset);
111 const auto in_vector = in_mat.col(in_offset);
114 arg_max_index_mat.col(out_offset) =
115 (out_vector.array() < in_vector.array())
116 .select(in_offset, arg_max_index_mat.col(out_offset));
119 out_mat.col(out_offset) = out_vector.cwiseMax(in_vector);
130 const int *arg_max_index,
const Shape &grad_shape,
float *grad_data)
136 std::fill(grad_data, grad_data + grad_shape.
FlatSize(), 0.0);
138 const int depth =
MatchingDim(grad_shape, 3, incoming_shape, 3);
143 for (
int col_index = 0; col_index < incoming_mat.cols(); col_index++)
145 auto arg_indices = arg_max_index_mat.col(col_index);
146 for (
int d = 0; d < depth; d++)
149 if (arg_indices(d) == -1)
152 grad_mat(d, arg_indices(d)) += incoming_mat(d, col_index);