51typedef IndexList<type2index<0>, type2index<0>, type2index<1>, type2index<1>>
ReverseColMajor;
52typedef IndexList<type2index<1>, type2index<1>, type2index<0>, type2index<0>>
ReverseRowMajor;
// Builds the lazy Eigen expression for SpatialConvolutionBackwardInput: per
// its name, the gradient of a 2-D spatial convolution with respect to its
// input. Patches are extracted from `output_backward` (inflated by the
// forward strides and padded) and contracted with the spatially reversed,
// shuffled and flattened kernel.
//
// Template parameters:
//   OutputBackward - tensor expression for the gradient flowing back from the
//                    convolution output.
//   Kernel         - tensor expression for the filters; must use the same
//                    Layout as OutputBackward (static-asserted below).
//
// Parameters:
//   kernel          - filters, read as [filters, channels, rows, cols] for
//                     ColMajor or [cols, rows, channels, filters] for RowMajor.
//   output_backward - output gradient; spatial rows/cols are read from dims
//                     1/2 (ColMajor) or NumDims-2/NumDims-3 (RowMajor); every
//                     other non-depth dim is carried through to the result.
//   inputRows, inputCols         - spatial size of the result.
//   row_stride, col_stride       - forward strides; passed as the inflate
//                                  strides of extract_image_patches.
//   row_in_stride, col_in_stride - kernel dilation; the effective kernel
//                                  extent is k + (k - 1) * (in_stride - 1).
//
// Returns: an expression with dimensions post_contract_dims, i.e.
// [channels, inputRows, inputCols, <output_backward dims 3..>] for ColMajor
// and the mirrored order for RowMajor. The reversed kernel and the shuffled
// kernel are materialized (TensorForcedEvalOp / .eval()); the rest of the
// returned expression is lazy.
//
// NOTE: the std::conditional_t return type spells out the exact expression
// type of each layout branch (ColMajor: kernel x patches, RowMajor:
// patches x kernel) and must be kept in sync with the return statement.
//
// NOTE(review): this listing has gaps. Branch conditions and braces, the
// `kern`/`out` TensorRef names and the `return choose(` head are not
// present. Paired assignment sets below are presumably the
// if (isColMajor) / else branches; confirm against the complete source.
54template <
typename OutputBackward,
typename Kernel>
55EIGEN_ALWAYS_INLINE
static const std::conditional_t<
56 internal::traits<OutputBackward>::Layout == ColMajor,
58 const DSizes<typename internal::traits<OutputBackward>::Index,
59 internal::traits<OutputBackward>::NumDimensions>,
60 const TensorContractionOp<
61 const array<IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
62 const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 2>,
63 const Eigen::TensorForcedEvalOp<
const TensorShufflingOp<
64 const array<typename internal::traits<OutputBackward>::Index, 4>,
65 const Eigen::TensorForcedEvalOp<
66 const TensorReverseOp<const ReverseColMajor, const Kernel>>>>>,
67 const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 2>,
68 const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward>>>>,
71 const DSizes<typename internal::traits<OutputBackward>::Index,
72 internal::traits<OutputBackward>::NumDimensions>,
73 const TensorContractionOp<
74 const array<IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
75 const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 2>,
76 const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward>>,
77 const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 2>,
78 const Eigen::TensorForcedEvalOp<
const TensorShufflingOp<
79 const array<typename internal::traits<OutputBackward>::Index, 4>,
80 const Eigen::TensorForcedEvalOp<
81 const TensorReverseOp<const ReverseRowMajor, const Kernel>>>>>>>>
82SpatialConvolutionBackwardInput(
const Kernel &kernel,
const OutputBackward &output_backward,
83 typename internal::traits<OutputBackward>::Index inputRows,
84 typename internal::traits<OutputBackward>::Index inputCols,
85 const DenseIndex row_stride = 1,
const DenseIndex col_stride = 1,
86 const DenseIndex row_in_stride = 1,
87 const DenseIndex col_in_stride = 1)
// Index/scalar aliases, plus TensorRef views of the kernel and of
// output_backward used to query dimensions.
// NOTE(review): `kern` and `out`, used below, presumably name these views.
89 typedef typename internal::traits<OutputBackward>::Index TensorIndex;
90 typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
93 internal::traits<Kernel>::Layout, TensorIndex>>
95 TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
96 internal::traits<OutputBackward>::Layout, TensorIndex>>
// The kernel and the output gradient must share a storage layout.
99 EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout == internal::traits<OutputBackward>::Layout,
100 YOU_MADE_A_PROGRAMMING_MISTAKE);
102 static const bool isColMajor = (internal::traits<OutputBackward>::Layout == ColMajor);
104 static const int NumDims = internal::traits<OutputBackward>::NumDimensions;
// Kernel geometry: dims are [filters, channels, rows, cols] for ColMajor and
// [cols, rows, channels, filters] for RowMajor.
108 const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
110 const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
111 const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
112 const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
// Effective (dilated) kernel extent: (in_stride - 1) zeros sit between each
// pair of consecutive kernel taps.
117 const TensorIndex kernelRowsEff = kernelRows + (kernelRows - 1) * (row_in_stride - 1);
118 const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (col_in_stride - 1);
120 const TensorIndex outputRows =
121 isColMajor ? output_backward.dimension(1) : output_backward.dimension(NumDims - 2);
122 const TensorIndex outputCols =
123 isColMajor ? output_backward.dimension(2) : output_backward.dimension(NumDims - 3);
// Recover the top/left padding of the forward pass (half of the total
// padding, clamped at 0). Then convert it into the padding needed for this
// transposed pass: kernelEff - 1 - forward_pad.
// NOTE(review): maxi is instantiated with `Index`, not TensorIndex;
// presumably Eigen::Index. Confirm the two index types agree.
126 const TensorIndex forward_pad_top =
127 numext::maxi<Index>(0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
128 const TensorIndex forward_pad_left =
129 numext::maxi<Index>(0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
130 const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
131 const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
// Bottom/right padding is chosen so that
//   (outputRows - 1) * row_stride + 1 + padding_top + padding_bottom
//     == inputRows + kernelRowsEff - 1
// (likewise for columns). A unit-stride sweep of the dilated kernel over the
// inflated, padded output gradient then yields exactly inputRows (inputCols)
// positions.
133 const TensorIndex padding_bottom =
134 inputRows - (outputRows - 1) * row_stride - 2 - padding_top + kernelRowsEff;
135 const TensorIndex padding_right =
136 inputCols - (outputCols - 1) * col_stride - 2 - padding_left + kernelColsEff;
// Sanity-check the derived paddings (eigen_assert is typically compiled out
// in release builds).
138 eigen_assert(padding_top >= 0);
139 eigen_assert(padding_left >= 0);
140 eigen_assert(padding_bottom >= 0);
141 eigen_assert(padding_right >= 0);
// Flip the kernel along its spatial axes. These are dims 2/3 for ColMajor
// and dims 0/1 for RowMajor, so the mask is chosen at compile time.
149 typedef std::conditional_t<isColMajor, ReverseColMajor, ReverseRowMajor> Reverse;
150 Reverse kernel_reverse;
// Reorder the reversed kernel so the filters/rows/cols axes are adjacent
// and can be flattened together:
//   ColMajor {0, 2, 3, 1}: [filters, channels, rows, cols]
//                          -> [filters, rows, cols, channels]
//   RowMajor {2, 0, 1, 3}: [cols, rows, channels, filters]
//                          -> [channels, cols, rows, filters]
153 array<TensorIndex, 4> kernel_shuffle;
158 kernel_shuffle[0] = 0;
159 kernel_shuffle[1] = 2;
160 kernel_shuffle[2] = 3;
161 kernel_shuffle[3] = 1;
167 kernel_shuffle[0] = 2;
168 kernel_shuffle[1] = 0;
169 kernel_shuffle[2] = 1;
170 kernel_shuffle[3] = 3;
// Flatten the shuffled kernel to 2-D:
// ColMajor (filters*rows*cols) x channels, RowMajor channels x (filters*rows*cols).
174 DSizes<TensorIndex, 2> kernel_dims;
177 kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
178 kernel_dims[1] = kernelChannels;
182 kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
183 kernel_dims[0] = kernelChannels;
// Flatten the extracted patches to 2-D. One axis is the patch contents
// (filters*rows*cols). The other is everything else: inputRows*inputCols
// times the non-spatial dims of `out` (dims 3.. for ColMajor, dims
// 0..NumDims-4 for RowMajor).
191 DSizes<TensorIndex, 2> pre_contract_dims;
194 pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
195 pre_contract_dims[1] = inputRows * inputCols;
196 for (
int i = 3; i < NumDims; ++i)
198 pre_contract_dims[1] *= out.dimension(i);
203 pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
204 pre_contract_dims[0] = inputRows * inputCols;
205 for (
int i = 0; i < NumDims - 3; ++i)
207 pre_contract_dims[0] *= out.dimension(i);
// Contract over the shared filters*rows*cols axis. That is dim 0 of both
// operands for ColMajor (kernel x patches) and dim 1 of both for RowMajor
// (patches x kernel).
213 array<IndexPair<TensorIndex>, 1> contract_dims;
217 contract_dims[0] = IndexPair<TensorIndex>(0, 0);
222 contract_dims[0] = IndexPair<TensorIndex>(1, 1);
// Unflatten the product:
// ColMajor [channels, inputRows, inputCols, out dims 3..];
// RowMajor [out dims 0..NumDims-4, inputCols, inputRows, channels].
227 DSizes<TensorIndex, NumDims> post_contract_dims;
230 post_contract_dims[0] = kernelChannels;
231 post_contract_dims[1] = inputRows;
232 post_contract_dims[2] = inputCols;
233 for (
int i = 3; i < NumDims; ++i)
235 post_contract_dims[i] = out.dimension(i);
240 post_contract_dims[NumDims - 1] = kernelChannels;
241 post_contract_dims[NumDims - 2] = inputRows;
242 post_contract_dims[NumDims - 3] = inputCols;
243 for (
int i = 0; i < NumDims - 3; ++i)
245 post_contract_dims[i] = out.dimension(i);
// Layout dispatch. ColMajor: the reversed/shuffled/flattened kernel
// contracted with the output-gradient patches. RowMajor: the same two
// operands with the patches on the left. extract_image_patches uses unit
// patch stride, the dilation as in-stride and the forward stride as inflate
// stride, zero-filling the padding.
254 Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
255 kernel.reverse(kernel_reverse)
257 .shuffle(kernel_shuffle)
259 .reshape(kernel_dims)
260 .contract(output_backward
261 .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride, col_in_stride,
262 row_stride, col_stride, padding_top, padding_bottom,
263 padding_left, padding_right, OutScalar(0))
264 .reshape(pre_contract_dims),
266 .reshape(post_contract_dims),
268 .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride, col_in_stride, row_stride,
269 col_stride, padding_top, padding_bottom, padding_left, padding_right,
271 .reshape(pre_contract_dims)
273 kernel.reverse(kernel_reverse).eval().shuffle(kernel_shuffle).eval().reshape(kernel_dims),
275 .reshape(post_contract_dims));
// Builds the lazy Eigen expression for SpatialConvolutionBackwardKernel: per
// its name, the gradient of a 2-D spatial convolution with respect to its
// filters. Patches of size inputRows x inputCols are extracted from the
// shuffled, stride-inflated and padded output gradient. They are then
// contracted with the flattened input over the batch and spatial positions.
//
// Template parameters:
//   OutputBackward - tensor expression for the output gradient.
//   Input          - tensor expression for the forward input. It must have
//                    the same Layout and rank as OutputBackward, and rank 4
//                    (static-asserted below).
//
// Parameters:
//   input           - forward input. Channels are dim 0 (ColMajor) or
//                     NumDims-1 (RowMajor); rows/cols are dims 1/2 or
//                     NumDims-2/NumDims-3; the remaining dim is the batch.
//   output_backward - output gradient. Filters are dim 0 (ColMajor) or
//                     NumDims-1 (RowMajor); rows/cols as for `input`.
//   kernelRows, kernelCols       - spatial size of the filters to produce.
//   row_stride, col_stride       - forward strides; passed as the inflate
//                                  strides of extract_image_patches.
//   row_in_stride, col_in_stride - kernel dilation; passed as the patch
//                                  stride. A stride > 1 and a dilation > 1
//                                  may not be combined on the same axis
//                                  (eigen_assert).
//
// Returns: an expression shaped [filters, channels, kernelRows, kernelCols]
// for ColMajor and [kernelCols, kernelRows, channels, filters] for RowMajor,
// with both spatial axes reversed.
//
// NOTE(review): this listing has gaps. Branch conditions and braces, the
// padRows/padCols and padding_top/padding_left declarations, the `in`
// TensorRef name and the tail of the return type are not present. Paired
// assignments below are presumably the if (isColMajor) / else branches;
// confirm against the complete source.
303template <
typename OutputBackward,
typename Input>
304EIGEN_ALWAYS_INLINE
static const std::conditional_t<
305 internal::traits<Input>::Layout == ColMajor,
306 const TensorReverseOp<
307 const Eigen::array<typename internal::traits<Input>::Index,
308 internal::traits<Input>::NumDimensions>,
309 const Eigen::TensorForcedEvalOp<
const Eigen::TensorShufflingOp<
310 const Eigen::array<typename internal::traits<Input>::Index,
311 internal::traits<Input>::NumDimensions>,
312 const Eigen::TensorReshapingOp<
313 const Eigen::DSizes<typename internal::traits<Input>::Index,
314 internal::traits<Input>::NumDimensions>,
315 const TensorContractionOp<
316 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
317 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
318 const Eigen::TensorForcedEvalOp<
const Eigen::TensorShufflingOp<
319 const Eigen::array<typename internal::traits<Input>::Index,
320 internal::traits<Input>::NumDimensions>,
322 const TensorReshapingOp<
323 const DSizes<typename internal::traits<Input>::Index, 2>,
324 const TensorImagePatchOp<Dynamic, Dynamic,
325 const Eigen::TensorForcedEvalOp<
const Eigen::TensorShufflingOp<
326 const Eigen::array<typename internal::traits<Input>::Index,
327 internal::traits<Input>::NumDimensions>,
328 const OutputBackward>>>>>>>>>,
329 const TensorReverseOp<
330 const Eigen::array<typename internal::traits<Input>::Index,
331 internal::traits<Input>::NumDimensions>,
332 const Eigen::TensorForcedEvalOp<
const Eigen::TensorShufflingOp<
333 const Eigen::array<typename internal::traits<Input>::Index,
334 internal::traits<Input>::NumDimensions>,
335 const Eigen::TensorReshapingOp<
336 const Eigen::DSizes<typename internal::traits<Input>::Index,
337 internal::traits<Input>::NumDimensions>,
338 const TensorContractionOp<
339 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
340 const TensorReshapingOp<
341 const DSizes<typename internal::traits<Input>::Index, 2>,
342 const TensorImagePatchOp<Dynamic, Dynamic,
343 const Eigen::TensorForcedEvalOp<
const Eigen::TensorShufflingOp<
344 const Eigen::array<typename internal::traits<Input>::Index,
345 internal::traits<Input>::NumDimensions>,
346 const OutputBackward>>>>,
347 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
348 const Eigen::TensorForcedEvalOp<
const Eigen::TensorShufflingOp<
349 const Eigen::array<typename internal::traits<Input>::Index,
350 internal::traits<Input>::NumDimensions>,
352SpatialConvolutionBackwardKernel(
const Input &input,
const OutputBackward &output_backward,
353 typename internal::traits<Input>::Index kernelRows,
354 typename internal::traits<Input>::Index kernelCols,
355 const DenseIndex row_stride = 1,
const DenseIndex col_stride = 1,
356 const DenseIndex row_in_stride = 1,
357 const DenseIndex col_in_stride = 1)
// Index/scalar aliases and TensorRef views of input and output_backward,
// used to query dimensions.
// NOTE(review): `in`, used below, presumably names the Input view.
359 typedef typename internal::traits<Input>::Index TensorIndex;
360 typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
361 TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions,
362 internal::traits<Input>::Layout, TensorIndex>>
364 TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
365 internal::traits<OutputBackward>::Layout, TensorIndex>>
366 out(output_backward);
368 EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<OutputBackward>::Layout,
369 YOU_MADE_A_PROGRAMMING_MISTAKE);
// A stride and a dilation may not both exceed 1 on the same axis.
372 eigen_assert(!(row_stride > 1 && row_in_stride > 1));
373 eigen_assert(!(col_stride > 1 && col_in_stride > 1));
375 static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
377 static const int NumDims = internal::traits<Input>::NumDimensions;
378 EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions ==
379 internal::traits<OutputBackward>::NumDimensions,
380 YOU_MADE_A_PROGRAMMING_MISTAKE);
381 EIGEN_STATIC_ASSERT(NumDims == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);
// Spatial sizes of the forward input and of the output gradient.
383 const TensorIndex inputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
384 const TensorIndex inputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
386 const TensorIndex outputRows =
387 isColMajor ? output_backward.dimension(1) : output_backward.dimension(NumDims - 2);
388 const TensorIndex outputCols =
389 isColMajor ? output_backward.dimension(2) : output_backward.dimension(NumDims - 3);
// The filter count is the output gradient's depth; the channel count is the
// input's depth.
393 const TensorIndex kernelFilters =
394 isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1];
397 const TensorIndex kernelChannels = isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1];
// Effective (dilated) kernel extent.
402 const TensorIndex kernelRowsEff = kernelRows + (kernelRows - 1) * (row_in_stride - 1);
403 const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (col_in_stride - 1);
// Product of the input's non-spatial, non-depth dims. With NumDims == 4 this
// is the single batch dim (dim 3 for ColMajor, dim 0 for RowMajor).
406 TensorIndex batch = 1;
407 for (
int d = 3; d < NumDims; ++d)
409 batch *= isColMajor ? in.dimension(d) : in.dimension(NumDims - d - 1);
// Total forward padding per axis, clamped at 0. These are the tails of the
// padRows/padCols declarations; padding_top/padding_left are presumably
// derived from them.
414 numext::maxi<Index>(0, (outputRows - 1) * row_stride + kernelRowsEff - inputRows);
416 numext::maxi<Index>(0, (outputCols - 1) * col_stride + kernelColsEff - inputCols);
// Pad the output gradient so that, once inflated by the strides
// (expanded_out_*), it spans inputRows + kernelRowsEff - 1 rows (and
// likewise for columns). An inputRows-tall patch swept over that with
// stride row_in_stride then has
//   (kernelRowsEff - 1) / row_in_stride + 1 == kernelRows
// positions, one per kernel tap.
422 const TensorIndex expanded_out_rows = (outputRows - 1) * row_stride + 1;
423 const TensorIndex expanded_out_cols = (outputCols - 1) * col_stride + 1;
425 const TensorIndex padded_out_rows = inputRows + kernelRowsEff - 1;
426 const TensorIndex padded_out_cols = inputCols + kernelColsEff - 1;
428 const TensorIndex top_pad_rows = kernelRowsEff - 1 - padding_top;
429 const TensorIndex left_pad_cols = kernelColsEff - 1 - padding_left;
431 const TensorIndex bottom_pad_rows = padded_out_rows - expanded_out_rows - top_pad_rows;
432 const TensorIndex right_pad_cols = padded_out_cols - expanded_out_cols - left_pad_cols;
// Output-gradient reorder: {3, 1, 2, 0} in both layouts, i.e. swap the
// first and last axes.
435 array<TensorIndex, 4> output_backward_shuffle;
440 output_backward_shuffle = {3, 1, 2, 0};
446 output_backward_shuffle = {3, 1, 2, 0};
// Input reorder:
//   ColMajor {0, 3, 1, 2}: [channels, rows, cols, batch] -> [channels, batch, rows, cols]
//   RowMajor {1, 2, 0, 3}: [batch, cols, rows, channels] -> [cols, rows, batch, channels]
450 array<TensorIndex, 4> input_shuffle;
455 input_shuffle = {0, 3, 1, 2};
461 input_shuffle = {1, 2, 0, 3};
// The shuffled input viewed as 2-D:
// ColMajor channels x (batch*rows*cols), RowMajor (cols*rows*batch) x channels.
465 DSizes<TensorIndex, 2> input_dims;
468 input_dims[0] = kernelChannels;
469 input_dims[1] = batch * inputRows * inputCols;
473 input_dims[1] = kernelChannels;
474 input_dims[0] = inputCols * inputRows * batch;
// The extracted patches viewed as 2-D:
// ColMajor (batch*inputRows*inputCols) x (kernelRows*kernelCols*filters),
// RowMajor (filters*kernelCols*kernelRows) x (inputCols*inputRows*batch).
481 DSizes<TensorIndex, 2> pre_contract_dims;
484 pre_contract_dims[0] = batch * inputRows * inputCols;
485 pre_contract_dims[1] = kernelRows * kernelCols * kernelFilters;
489 pre_contract_dims[1] = inputCols * inputRows * batch;
490 pre_contract_dims[0] = kernelFilters * kernelCols * kernelRows;
// Contract dim 1 of the left operand with dim 0 of the right. In both
// layouts this pairs the flattened batch/spatial axes, because the operand
// order is swapped between layouts.
495 array<IndexPair<TensorIndex>, 1> contract_dims;
496 contract_dims[0] = IndexPair<TensorIndex>(1, 0);
// Unflatten the product:
// ColMajor [channels, kernelRows, kernelCols, filters];
// RowMajor [filters, kernelCols, kernelRows, channels].
499 DSizes<TensorIndex, NumDims> post_contract_dims;
502 post_contract_dims[0] = kernelChannels;
503 post_contract_dims[1] = kernelRows;
504 post_contract_dims[2] = kernelCols;
505 post_contract_dims[3] = kernelFilters;
509 post_contract_dims[0] = kernelFilters;
510 post_contract_dims[1] = kernelCols;
511 post_contract_dims[2] = kernelRows;
512 post_contract_dims[3] = kernelChannels;
// Reorder into filter layout:
//   ColMajor {3, 0, 1, 2}: -> [filters, channels, kernelRows, kernelCols]
//   RowMajor {1, 2, 3, 0}: -> [kernelCols, kernelRows, channels, filters]
516 array<TensorIndex, 4> kernel_shuffle;
521 kernel_shuffle = {3, 0, 1, 2};
527 kernel_shuffle = {1, 2, 3, 0};
// Flip the spatial axes (2 and 3 for ColMajor, 0 and 1 for RowMajor).
// NOTE(review): the bool flags are stored in a TensorIndex array
// (false -> 0, true -> 1). The element type has to match the
// Eigen::array<Index, NumDimensions> spelled in the return type.
531 array<TensorIndex, 4> kernel_reverse;
534 kernel_reverse = {
false,
false,
true,
true};
538 kernel_reverse = {
true,
true,
false,
false};
// Materialize the shuffled output gradient and the shuffled input once
// (.eval()). The input is then viewed as 2-D via input_dims.
543 const auto output_backward_shuffled = output_backward.shuffle(output_backward_shuffle).eval();
547 const auto input_shuffled =
input.shuffle(input_shuffle).eval().reshape(input_dims);
// ColMajor: input x patches. RowMajor: patches x input. Patches use the
// dilation as patch stride, unit in-stride and the forward stride as
// inflate stride, zero-filling the padding. Either product is reshaped to
// 4-D, shuffled into filter order and spatially reversed.
549 return choose(Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
550 input_shuffled.contract(
551 output_backward_shuffled
552 .extract_image_patches(inputRows, inputCols, row_in_stride, col_in_stride, 1, 1,
553 row_stride, col_stride, top_pad_rows, bottom_pad_rows,
554 left_pad_cols, right_pad_cols, OutScalar(0))
557 output_backward_shuffled
558 .extract_image_patches(inputRows, inputCols, row_in_stride, col_in_stride, 1, 1,
559 row_stride, col_stride, top_pad_rows, bottom_pad_rows,
560 left_pad_cols, right_pad_cols, OutScalar(0))
562 .contract(input_shuffled, contract_dims))
563 .reshape(post_contract_dims)
564 .shuffle(kernel_shuffle)
566 .reverse(kernel_reverse);