class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
      Self;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
      SubMapper;

  typedef typename packet_traits<Scalar>::type Packet;
  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&, const contract_t&)
      : m_impl(tensor.impl().impl()) {
    Index patch_rows;
    Index patch_depth;
    if (internal::traits<ArgType>::Layout == ColMajor) {
      patch_depth = tensor.impl().dimensions()[0];
      patch_rows = tensor.impl().dimensions()[1];
      m_patch_cols = tensor.impl().dimensions()[2];
      m_num_patches = tensor.impl().dimensions()[3];
    } else {
      const size_t NumDims = tensor.impl().dimensions().size();
      patch_depth = tensor.impl().dimensions()[NumDims - 1];
      patch_rows = tensor.impl().dimensions()[NumDims - 2];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
      m_num_patches = tensor.impl().dimensions()[NumDims - 4];
    }
    m_patch_row_stride = patch_depth;
    m_patch_col_stride = patch_rows * m_patch_row_stride;

    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    m_colStride = patch_rows;

    m_outputRows = tensor.impl().outputRows();
    m_outputCols = tensor.impl().outputCols();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();
    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputRows = tensor.impl().impl().dimensions()[1];
      m_inputCols = tensor.impl().impl().dimensions()[2];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    }
    m_rowInputStride = patch_depth;
    m_colInputStride = patch_depth * m_inputRows;
    m_patchInputStride = patch_depth * m_inputRows * m_inputCols;

    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastPatchRowStride = internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride = internal::TensorIntDivisor<Index>(m_patch_col_stride);
    m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
  }
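  // Illustrative example (hypothetical sizes): with patch_depth = 4 and
  // patch_rows = 3, m_patch_row_stride = 4 and m_patch_col_stride = 12, so
  // coefficient k of a patch lives at depth k % 4, patch row (k / 4) % 3 and
  // patch column k / 12. The TensorIntDivisor members initialized above are
  // fast integer divisors for exactly these divisions, which sit on the hot
  // path of every coefficient and packet load below.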
  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_colStride = base_mapper.m_colStride;

    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;

    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputRows = base_mapper.m_outputRows;
    m_outputCols = base_mapper.m_outputCols;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastDimZero = base_mapper.m_fastDimZero;
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_row_strides != 1 || m_in_col_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_ALWAYS_INLINE Scalar operator()(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
 private:
  friend class TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>;
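  // Within a patch, a linear coefficient index patchId decomposes (in
  // column-major patch order) as
  //   patchId = depth + patchDepth() * (rowOffset + patchRows() * colOffset)
  // which is what the loaders below invert: patchOffset = patchId / depth,
  // colOffset = patchOffset / patchRows(), and rowOffset and depth are the
  // remainders. Hypothetical example: patchDepth() = 4, patchRows() = 3 and
  // patchId = 17 give depth 1, rowOffset 1, colOffset 1.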
  // Load the coefficient at linear index patchId within the patch whose base
  // input indices (rowIndex, colIndex, otherIndex) have been precomputed.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex,
                                       Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol = (m_patch_col_inflate_strides == 1)
                                   ? inputCol
                                   : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);

    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow = (m_patch_row_inflate_strides == 1)
                                   ? inputRow
                                   : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
        origInputRow >= m_inputRows || (inputCol != origInputCol * m_patch_col_inflate_strides) ||
        (inputRow != origInputRow * m_patch_row_inflate_strides)) {
      // The coefficient lies in the padded (or inflated) region: zero-pad.
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex =
        depth + origInputRow * m_rowInputStride + origInputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }
  // Same as loadCoeff(), but assumes unit input strides and no inflation
  // (standard patches), which removes the inflation checks.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex,
                                               Index otherIndex) const {
    eigen_assert(!nonStandardPatches());

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows) {
      // The coefficient lies in the padded region: zero-pad.
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex =
        depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    }
    typedef decltype(m_impl) TensorEvaluatorT;
    return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex, colIndex, otherIndex);
  }
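  // The helpers below organize packet loads as follows:
  //  - loadPacketFast: patchDepth() is a multiple of the packet size, so a
  //    packet never straddles a row/column boundary of the patch;
  //  - loadPacketStandardFromSingleColumn: the packet stays within a single
  //    input column and can be loaded contiguously (or zero-filled when it is
  //    entirely outside the image);
  //  - loadPacketStandardFromTwoColumns: the packet straddles two adjacent
  //    columns and is assembled from two partial (masked) loads OR-ed
  //    together; only used when the evaluator supports partial packet loads;
  //  - packetWithPossibleZero: fallback that gathers the coefficients one by
  //    one through loadCoeff().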
  // Load a 'partial' packet: only the lanes in the range [span[0], span[1]]
  // are loaded from the column described by (colOffset, patchOffsets); the
  // remaining lanes are zero-filled. Called from
  // loadPacketStandardFromTwoColumns() when the evaluator supports partial
  // (masked) packet loads.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(Index rowIndex, Index colIndex,
                                                       Index otherIndex, Index patchId,
                                                       const Index span[],
                                                       const Index patchOffsets[],
                                                       Index colOffset) const {
    const Index inputCol = colIndex + colOffset;
    const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
                                 patchOffsets[1] - colOffset * m_colStride};
    const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};

    if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || inputCol >= m_inputCols || inputCol < 0) {
      // The partial packet is all zeros.
      return internal::pset1<Packet>(Scalar(0));
    } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
      // The requested lanes are all inside the image: use a masked load
      // starting at inputIndex - span[0] so that lane span[0] maps to inputIndex.
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex =
          depth + inputRows[0] * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
      return m_impl.template partialPacket<Packet>(inputIndex - span[0],
                                                   mask<Packet>(span[0], span[1] + 1));
    } else {
      // Slow path: assemble the partial packet coefficient by coefficient,
      // zero-filling the lanes outside [span[0], span[1]].
      const Index packetSize = internal::unpacket_traits<Packet>::size;
      EIGEN_ALIGN_MAX std::remove_const_t<Scalar> values[packetSize];
      for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
      for (int i = span[0]; i < span[1] + 1; ++i)
        values[i] = loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
      for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
      return internal::pload<Packet>(values);
    }
  }
  // Load a packet that straddles two adjacent columns of the patch by OR-ing
  // together two partial packets, one per column.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(Index patchId, Index rowIndex,
                                                              Index colIndex, Index otherIndex,
                                                              const Index patchOffsets[],
                                                              const Index colOffsets[]) const {
    eigen_assert(colOffsets[1] == colOffsets[0] + 1);
    const Index packetSize = internal::unpacket_traits<Packet>::size;

    // patchIdSplit is the last coefficient index that still belongs to the
    // first of the two columns.
    const Index patchIdSplit = ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
    const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;

    // patchIds[i]: first coefficient to load from column i;
    // spans[i]: lane range of the packet filled from column i;
    // patchOffsets2Cols[i]: patch offsets spanned by column i.
    const Index patchIds[2] = {patchId, patchIdSplit + 1};
    const Index spans[2][2] = {{0, patchIdSplit - patchId},
                               {patchIdSplit - patchId + 1, packetSize - 1}};
    const Index patchOffsets2Cols[2][2] = {{patchOffsets[0], patchOffsetSplit},
                                           {patchOffsetSplit + 1, patchOffsets[1]}};

    // Load the two partial packets and OR them into the full packet.
    return internal::por<Packet>(
        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], spans[0],
                                  patchOffsets2Cols[0], colOffsets[0]),
        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], spans[1],
                                  patchOffsets2Cols[1], colOffsets[1]));
  }
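  // Worked example for the two-column split above (hypothetical sizes): with
  // patchDepth() = 3 and patchRows() = 4 (so one column holds 12
  // coefficients), colOffsets = {0, 1}, patchId = 10 and a packet size of 4,
  // patchIdSplit = 1 * 4 * 3 - 1 = 11 and spans = {{0, 1}, {2, 3}}: packet
  // lanes 0-1 come from column 0 (patchIds 10-11) and lanes 2-3 from column 1
  // (patchIds 12-13).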
  // Load a packet whose first and last coefficients map to the same input
  // column (a single-column packet).
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(Index patchId, Index rowIndex,
                                                                Index colIndex, Index otherIndex,
                                                                const Index patchOffsets[],
                                                                const Index colOffsets[],
                                                                const Index inputCols[]) const {
    eigen_assert(colOffsets[0] == colOffsets[1]);
    const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
                                 patchOffsets[1] - colOffsets[1] * m_colStride};
    eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};

    if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
      // All coefficients fall outside the image: return zeros.
      return internal::pset1<Packet>(Scalar(0));
    }
    if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
      // All coefficients are inside the image: one contiguous load suffices.
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex =
          depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
      return m_impl.template packet<Unaligned>(inputIndex);
    }
    // The packet crosses the image boundary: fall back to per-coefficient loads.
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }
  // loadPacketStandard() variant for evaluators that do NOT support partial
  // packet loads: a packet spanning two columns goes through the slow path.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
  loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    }

    // Offsets and input calculation here are identical to loadCoeffStandard(),
    // but computed for both ends of the packet.
    const Index patchOffsets[2] = {patchId / m_fastDimZero,
                                   (patchId + packetSize - 1) / m_fastDimZero};
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};
    const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};

    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
      // All coefficients fall outside the image: return zeros.
      return internal::pset1<Packet>(Scalar(0));
    }
    if (inputCols[0] == inputCols[1]) {
      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
                                                patchOffsets, colOffsets, inputCols);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }
  // loadPacketStandard() variant for evaluators that DO support partial packet
  // loads: a packet spanning two adjacent columns is assembled from two masked
  // partial loads instead of falling back to the slow path.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
  loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<PacketT>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    }

    // Offsets and input calculation here are identical to loadCoeffStandard(),
    // but computed for both ends of the packet.
    const Index patchOffsets[2] = {patchId / m_fastDimZero,
                                   (patchId + packetSize - 1) / m_fastDimZero};
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};
    const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};

    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
      // All coefficients fall outside the image: return zeros.
      return internal::pset1<PacketT>(Scalar(0));
    }
    if (inputCols[0] == inputCols[1]) {
      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
                                                patchOffsets, colOffsets, inputCols);
    }
    if (inputCols[1] == inputCols[0] + 1) {
      return loadPacketStandardFromTwoColumns(patchId, rowIndex, colIndex, otherIndex,
                                              patchOffsets, colOffsets);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }
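  // When patchDepth() is a multiple of the packet size, all lanes of the
  // packet share the same patch offset (asserted below), i.e. the packet never
  // straddles a row or column boundary of the patch and can be read as one
  // contiguous run along the depth dimension of the input.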
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || inputRow >= m_inputRows) {
      // All coefficients fall outside the image: return zeros.
      return internal::pset1<Packet>(Scalar(0));
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex =
        depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex,
                                                                      Index colIndex,
                                                                      Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX std::remove_const_t<Scalar> values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(Index patchIndex, Index& rowIndex,
                                                                Index& colIndex,
                                                                Index& otherIndex) const {
    const size_t NumInputDims =
        array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
    const Index patch2DIndex =
        (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
    otherIndex *= m_patchInputStride;
    colIndex = patch2DIndex / m_fastOutputRows;
    rowIndex = patch2DIndex - colIndex * m_outputRows;
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
  }
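  // Hypothetical example for computeBaseIndices: with m_outputRows = 5, unit
  // user strides and no padding, patchIndex = 7 maps to output column 1 and
  // output row 2, i.e. rowIndex = 2 and colIndex = 1 in input coordinates;
  // non-zero padding shifts both towards negative indices, which the loaders
  // above treat as zero-padded coefficients.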
  Index m_patch_cols;   // number of columns in the patch
  Index m_num_patches;  // number of patches to extract

  // Strides for navigating through a single patch.
  Index m_patch_row_stride;
  Index m_patch_col_stride;
  internal::TensorIntDivisor<Index> m_fastPatchRowStride;
  internal::TensorIntDivisor<Index> m_fastPatchColStride;

  Index m_patch_row_inflate_strides;  // row inflation stride in the image patch
  Index m_patch_col_inflate_strides;  // col inflation stride in the image patch
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  Index m_colStride;  // patch rows (stride, in patch offsets, between patch columns)
  internal::TensorIntDivisor<Index> m_fastNumPatches;
  internal::TensorIntDivisor<Index> m_fastColStride;

  Index m_rowInputStride;    // row stride in the input tensor
  Index m_colInputStride;    // col stride in the input tensor
  Index m_patchInputStride;  // patch stride in the input tensor

  Index m_inputRows;  // number of rows in the input tensor
  Index m_inputCols;  // number of cols in the input tensor
  Index m_outputRows;  // number of convolution output rows
  Index m_outputCols;  // number of convolution output cols

  Index m_row_strides;  // user-specified row stride
  Index m_col_strides;  // user-specified col stride
  Index m_in_row_strides;  // user-specified input row stride
  Index m_in_col_strides;  // user-specified input col stride

  Index m_rowPaddingTop;   // row padding
  Index m_colPaddingLeft;  // column padding

  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastDimZero;  // fast divisor by patchDepth()

  const TensorEvaluator<ArgType, Device> m_impl;
};
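// A minimal sketch of the expression pattern that reaches this specialization
// (variable names and dimension arrays below are illustrative, not part of
// this file): an image-patch extraction reshaped into a 2-D matrix and fed to
// a contraction with a reshaped kernel, e.g.
//
//   Eigen::Tensor<float, 4> input(depth, rows, cols, batch);
//   Eigen::Tensor<float, 4> kernel(out_depth, depth, k_rows, k_cols);
//   Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = {
//       Eigen::IndexPair<Eigen::Index>(1, 0)};
//   auto result = input.extract_image_patches(k_rows, k_cols)
//                     .reshape(patches_2d_dims)
//                     .contract(kernel.reshape(kernel_2d_dims), contract_dims);
//
// The contraction evaluator then instantiates this mapper over the reshaped
// image-patch operand, so the patches are gathered on the fly instead of being
// materialized in memory.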