ONE - On-device Neural Engine
eigen_spatial_convolutions-inl.h
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
19#define __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
20
22
23// Note this header is used in both TF and TFLite.
24namespace Eigen
25{
26
27namespace internal
28{
29
30// WARNING: Most of the code here implicitly assumes that the matrix is in
31// ColMajor layout. This is guaranteed by the tensor contraction (see
32// TensorContraction.h).
33//
34// Inside Eigen a tensor contraction is represented by a matrix multiplication.
35// We don't want to actually extract image patches and reshape the result into
36// a matrix (this involves allocating huge extra memory), so the patch
37// extraction and reshape operations are implicit.
38//
39// TensorContractionInputMapper takes a matrix index and returns the coefficient
40// (or the packet) of the "virtual tensor", that would be at that index if we
41// were to actually reshape the result of patch extraction.
42//
43// TensorContractionSubMapper provides a similar view into the "virtual matrix"
44// at the given vertical and horizontal offsets.
45//
46// "Virtual matrix" dimensions:
47// *0: kernelChannels * kernelRows * kernelCols;
48// 1: out_height * out_width * OTHERS (e.g. batches, etc...)
49//
50// *) extracted patches are contiguous in memory (innermost dimension assuming
51// col major layout)
52//
53// With these dimensions:
54// row - offset within a single patch (in code: patchId)
55// col - index of the extracted patch (in code: patchIndex)
56// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
57//
58// TODO(ezhulenev): Consolidate this part of the code with the image patch
59// extraction code since they are both very similar.
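//
// Illustrative example (hypothetical sizes): for a 3x3 kernel over an
// 8-channel input producing a 14x14 output on a batch of 32, the "virtual
// matrix" has 8 * 3 * 3 = 72 rows (patchId in [0, 72)) and
// 14 * 14 * 32 = 6272 columns (patchIndex in [0, 6272)), even though no such
// matrix is ever materialized.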
60
61template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
62 typename Scalar_, typename Index, typename nocontract_t, typename contract_t, int Side,
63 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
64class TensorContractionInputMapper<
65 Scalar_, Index, Side,
66 TensorEvaluator<
67 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
68 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
69{
70public:
71 typedef Scalar_ Scalar;
72
73 typedef TensorContractionInputMapper<
74 Scalar, Index, Side,
75 TensorEvaluator<
76 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
77 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
78 Self;
79
80 typedef TensorContractionSubMapper<
81 Scalar, Index, Side,
82 TensorEvaluator<
83 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
84 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
85 SubMapper;
86
87 typedef SubMapper VectorMapper;
88 typedef SubMapper LinearMapper;
89 typedef typename packet_traits<Scalar>::type Packet;
90
91 typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
92
93 EIGEN_DEVICE_FUNC
94 TensorContractionInputMapper(
95 const TensorEvaluator<
96 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>
97 &tensor,
98 const nocontract_t &, const nocontract_t &, const contract_t &, const contract_t &)
99 : m_impl(tensor.impl().impl())
100 {
101 Index patch_rows;
102 Index patch_depth;
103 if (internal::traits<ArgType>::Layout == ColMajor)
104 {
105 patch_depth = tensor.impl().dimensions()[0];
106 patch_rows = tensor.impl().dimensions()[1];
107 m_patch_cols = tensor.impl().dimensions()[2];
108 m_num_patches = tensor.impl().dimensions()[3];
109 }
110 else
111 {
112 const size_t NumDims = tensor.impl().dimensions().size();
113 patch_depth = tensor.impl().dimensions()[NumDims - 1];
114 patch_rows = tensor.impl().dimensions()[NumDims - 2];
115 m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
116 m_num_patches = tensor.impl().dimensions()[NumDims - 4];
117 }
118
119 // Strides for navigating through the single patch.
120 m_patch_row_stride = patch_depth;
121 m_patch_col_stride = patch_rows * m_patch_row_stride;
122
123 m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
124 m_patch_col_inflate_strides = tensor.impl().colInflateStride();
125
126 m_colStride = patch_rows;
127
128 m_outputRows = tensor.impl().outputRows();
129 m_outputCols = tensor.impl().outputCols();
130 m_row_strides = tensor.impl().userRowStride();
131 m_col_strides = tensor.impl().userColStride();
132
133 m_in_row_strides = tensor.impl().userInRowStride();
134 m_in_col_strides = tensor.impl().userInColStride();
135
136 if (internal::traits<ArgType>::Layout == ColMajor)
137 {
138 m_inputRows = tensor.impl().impl().dimensions()[1];
139 m_inputCols = tensor.impl().impl().dimensions()[2];
140 }
141 else
142 {
143 const int NumDims = tensor.impl().impl().dimensions().size();
144 m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
145 m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
146 }
147
148 m_rowInputStride = patch_depth;
149 m_colInputStride = patch_depth * m_inputRows;
150 m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
151
152 m_rowPaddingTop = tensor.impl().rowPaddingTop();
153 m_colPaddingLeft = tensor.impl().colPaddingLeft();
154
155 m_fastPatchRowStride = internal::TensorIntDivisor<Index>(m_patch_row_stride);
156 m_fastPatchColStride = internal::TensorIntDivisor<Index>(m_patch_col_stride);
157 m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
158 m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
159 m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
160 m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
161 m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
162 m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
163 }
164
165 EIGEN_DEVICE_FUNC
166 TensorContractionInputMapper(const TensorContractionInputMapper &base_mapper)
167 : m_impl(base_mapper.m_impl)
168 {
169 m_patch_cols = base_mapper.m_patch_cols;
170 m_num_patches = base_mapper.m_num_patches;
171
172 m_patch_row_stride = base_mapper.m_patch_row_stride;
173 m_patch_col_stride = base_mapper.m_patch_col_stride;
174
175 m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
176 m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
177
178 m_colStride = base_mapper.m_colStride;
179
180 m_rowInputStride = base_mapper.m_rowInputStride;
181 m_colInputStride = base_mapper.m_colInputStride;
182 m_patchInputStride = base_mapper.m_patchInputStride;
183
184 m_inputRows = base_mapper.m_inputRows;
185 m_inputCols = base_mapper.m_inputCols;
186
187 m_outputRows = base_mapper.m_outputRows;
188 m_outputCols = base_mapper.m_outputCols;
189 m_row_strides = base_mapper.m_row_strides;
190 m_col_strides = base_mapper.m_col_strides;
191
192 m_in_row_strides = base_mapper.m_in_row_strides;
193 m_in_col_strides = base_mapper.m_in_col_strides;
194
195 m_rowPaddingTop = base_mapper.m_rowPaddingTop;
196 m_colPaddingLeft = base_mapper.m_colPaddingLeft;
197
198 m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
199 m_fastPatchColStride = base_mapper.m_fastPatchColStride;
200 m_fastInputRowStride = base_mapper.m_fastInputRowStride;
201 m_fastInputColStride = base_mapper.m_fastInputColStride;
202 m_fastNumPatches = base_mapper.m_fastNumPatches;
203 m_fastColStride = base_mapper.m_fastColStride;
204 m_fastOutputRows = base_mapper.m_fastOutputRows;
205 m_fastDimZero = base_mapper.m_fastDimZero;
206 }
207
208 // If true, turns off some optimizations for loading packets since the image
209 // patches are "non-standard", e.g. there are non-trivial strides or
210 // inflations in the input.
211 EIGEN_DEVICE_FUNC
212 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const
213 {
214 return m_in_row_strides != 1 || m_in_col_strides != 1 || m_patch_row_inflate_strides != 1 ||
215 m_patch_col_inflate_strides != 1;
216 }
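  // Illustrative note: patches extracted with a dilation of 2 (i.e.
  // m_in_row_strides == m_in_col_strides == 2) or from an inflated input
  // (inflate strides != 1) are "non-standard" in this sense, so the
  // vectorized fast paths below fall back to packetWithPossibleZero().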
217
218 EIGEN_DEVICE_FUNC
219 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const
220 {
221 return SubMapper(*this, i, j);
222 }
223
224 EIGEN_DEVICE_FUNC
225 EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const
226 {
227 return LinearMapper(*this, i, j);
228 }
229
230 EIGEN_DEVICE_FUNC
231 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const
232 {
233 Index rowIndex, colIndex, otherIndex;
234 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
235 return loadCoeff(row, rowIndex, colIndex, otherIndex);
236 }
237
238 // Load the coefficient at the patchIndex location instead of the usual
239 // m_rowIndex, m_colIndex, m_otherIndex.
240 // This is currently only used by the gpu code.
242 EIGEN_DEVICE_FUNC
243 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const
244 {
245 Index rowIndex, colIndex, otherIndex;
246 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
247 return loadCoeff(row, rowIndex, colIndex, otherIndex);
248 }
249
250 EIGEN_DEVICE_FUNC
251 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const
252 {
253 Index rowIndex, colIndex, otherIndex;
254 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
255 return loadPacket(row, rowIndex, colIndex, otherIndex);
256 }
257
258 // Load the packet at the patchIndex location instead of the usual m_rowIndex,
259 // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
260 EIGEN_DEVICE_FUNC
261 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const
262 {
263 Index rowIndex, colIndex, otherIndex;
264 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
265 return loadPacket(row, rowIndex, colIndex, otherIndex);
266 }
267
268 EIGEN_DEVICE_FUNC
269 EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device> &impl() const { return m_impl; }
270
271 EIGEN_DEVICE_FUNC
272 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
273 EIGEN_DEVICE_FUNC
274 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
275 EIGEN_DEVICE_FUNC
276 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
277
278private:
279 friend class TensorContractionSubMapper<
280 Scalar, Index, Side,
281 TensorEvaluator<
282 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
283 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>;
284
285 // Load coefficient from a patch specified by the "within patch offset"
286 // (patchId) and the precomputed indices of the first element of the patch.
287 EIGEN_DEVICE_FUNC
288 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex,
289 Index otherIndex) const
290 {
291 // Find the offset of the element wrt the location of the first element.
292 const Index patchOffset = patchId / m_fastDimZero;
293
294 const Index colOffset = patchOffset / m_fastColStride;
295 const Index inputCol = colIndex + colOffset * m_in_col_strides;
296 const Index origInputCol = (m_patch_col_inflate_strides == 1)
297 ? inputCol
298 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
299
300 const Index rowOffset = patchOffset - colOffset * m_colStride;
301 const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
302 const Index origInputRow = (m_patch_row_inflate_strides == 1)
303 ? inputRow
304 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
305 if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
306 origInputRow >= m_inputRows || (inputCol != origInputCol * m_patch_col_inflate_strides) ||
307 (inputRow != origInputRow * m_patch_row_inflate_strides))
308 {
309 return Scalar(0);
310 }
311 const Index depth = patchId - patchOffset * patchDepth();
312 const Index inputIndex =
313 depth + origInputRow * m_rowInputStride + origInputCol * m_colInputStride + otherIndex;
314 return m_impl.coeff(inputIndex);
315 }
316
317 // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
318 // and `in_strides` equal to 1 (template specialization without templates).
319 EIGEN_DEVICE_FUNC
320 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex,
321 Index otherIndex) const
322 {
323 eigen_assert(!nonStandardPatches());
324
325 // Find the offset of the element wrt the location of the first element.
326 const Index patchOffset = patchId / m_fastDimZero;
327 const Index colOffset = patchOffset / m_fastColStride;
328 const Index rowOffset = patchOffset - colOffset * m_colStride;
329 const Index inputCol = colIndex + colOffset;
330 const Index inputRow = rowIndex + rowOffset;
331 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows)
332 {
333 return Scalar(0);
334 }
335 const Index depth = patchId - patchOffset * patchDepth();
336 const Index inputIndex =
337 depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
338 return m_impl.coeff(inputIndex);
339 }
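  // Worked example of the decomposition above (hypothetical sizes): with
  // patchDepth() == 8 and patchRows() == 3, patchId == 50 gives
  // patchOffset = 50 / 8 = 6, colOffset = 6 / 3 = 2, rowOffset = 0 and
  // depth = 50 - 6 * 8 = 2, i.e. element 50 lives at depth 2, row 0, col 2
  // within the patch.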
340
341 // Load packet from a patch specified by the "within patch offset"
342 // (patchId) and the precomputed indices of the first element of the patch.
343 EIGEN_DEVICE_FUNC
344 EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex,
345 Index otherIndex) const
346 {
347 const Index packetSize = internal::unpacket_traits<Packet>::size;
348 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
349 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
350
351 if (nonStandardPatches())
352 {
353 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
354 }
355 typedef decltype(m_impl) TensorEvaluatorT;
356 return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex, colIndex, otherIndex);
357 }
358
359 // Helper function to load a 'partial' packet - this is the single column
360 // part of a packet that is split across two columns. In the 'partial' packet,
361 // the elements corresponding to the column (specified through colOffset) are
362 // loaded and the rest of the elements are zero-filled into the 'partial'
363 // packet. This function is called from loadPacketStandardFromTwoColumns().
364 // This code path is exercised only when the packet type supports masked load
365 // and when the partial packet load is available in the TensorEvaluator.
366 EIGEN_DEVICE_FUNC
367 EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(Index rowIndex, Index colIndex,
368 Index otherIndex, Index patchId,
369 const Index span[],
370 const Index patchOffsets[],
371 Index colOffset) const
372 {
373 const Index inputCol = colIndex + colOffset;
374 const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
375 patchOffsets[1] - colOffset * m_colStride};
376 const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
377
378 if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || inputCol >= m_inputCols || inputCol < 0)
379 {
380 // Partial packet is all zeros
381 return internal::pset1<Packet>(Scalar(0));
382 }
383 else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
384 {
385 // From inputIndex-span[0], we need to load elements starting from index
386 // span[0] all the way up to (and including) span[1].
387 const Index depth = patchId - patchOffsets[0] * patchDepth();
388 const Index inputIndex =
389 depth + inputRows[0] * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
390 return m_impl.template partialPacket<Packet>(inputIndex - span[0],
391 mask<Packet>(span[0], span[1] + 1));
392 }
393 else
394 {
395 // Using slow path for this partial packet.
396 // We need to load elements starting from index span[0] all the way up to
397 // (and including) span[1]. We split this load into 3 parts:
398 // 0 : span[0]-1 - Zeros will be loaded for these indices
399 // span[0] : span[1] - Elements will be loaded here for these indices
400 // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
401 const Index packetSize = internal::unpacket_traits<Packet>::size;
402 EIGEN_ALIGN_MAX
403 std::remove_const_t<Scalar> values[packetSize];
404 for (int i = 0; i < span[0]; ++i)
405 values[i] = Scalar(0);
406 for (int i = span[0]; i < span[1] + 1; ++i)
407 values[i] = loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
408 for (int i = span[1] + 1; i < packetSize; ++i)
409 values[i] = Scalar(0);
410 return internal::pload<Packet>(values);
411 }
412 }
413
414 // Helper function to load a packet that is split across two columns.
415 // If required, this function is called from loadPacketStandard() when the
416 // packet type supports masked load and when the partial packet load is
417 // available in the TensorEvaluator.
418 EIGEN_DEVICE_FUNC
419 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(Index patchId, Index rowIndex,
420 Index colIndex, Index otherIndex,
421 const Index patchOffsets[],
422 const Index colOffsets[]) const
423 {
424 eigen_assert(colOffsets[1] == colOffsets[0] + 1);
425 const Index packetSize = internal::unpacket_traits<Packet>::size;
426
427 // Packet to load will be split into 2 parts where each part spans a single
428 // column. First determine where to split.
429 const Index patchIdSplit = ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
430 const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
431
432 // patchIds[i]: patchId corresponding to partial packet i
433 // spans[i]: Start and end indices corresponding to the elements
434 // to be loaded for partial packet i
435 // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
436 const Index patchIds[2] = {patchId, patchIdSplit + 1};
437 const Index spans[2][2] = {{0, patchIdSplit - patchId},
438 {patchIdSplit - patchId + 1, packetSize - 1}};
439 const Index patchOffsets2Cols[2][2] = {{patchOffsets[0], patchOffsetSplit},
440 {patchOffsetSplit + 1, patchOffsets[1]}};
441
442 // Load partial packets and do bit-wise OR to generate required packet
443 return internal::por<Packet>(
444 loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], spans[0],
445 patchOffsets2Cols[0], colOffsets[0]),
446 loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], spans[1],
447 patchOffsets2Cols[1], colOffsets[1]));
448 }
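  // Worked example of the split above (hypothetical sizes): with
  // patchDepth() == 3, patchRows() == 2 (6 elements per patch column) and a
  // packet of size 4 starting at patchId == 4, colOffsets is {0, 1} and
  // patchIdSplit = (1 * 2) * 3 - 1 = 5, so spans[0] == {0, 1} covers the two
  // packet elements that come from the first column, spans[1] == {2, 3}
  // covers the two that come from the second column, and the two partial
  // packets are OR-ed together above.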
449
450 // Helper function to load a packet that is present in a single column.
451 // If required, this function is called from loadPacketStandard().
452 EIGEN_DEVICE_FUNC
453 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(Index patchId, Index rowIndex,
454 Index colIndex, Index otherIndex,
455 const Index patchOffsets[],
456 const Index colOffsets[],
457 const Index inputCols[]) const
458 {
459 eigen_assert(colOffsets[0] == colOffsets[1]);
460 const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
461 patchOffsets[1] - colOffsets[1] * m_colStride};
462 eigen_assert(rowOffsets[0] <= rowOffsets[1]);
463 const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
464
465 if (inputRows[0] >= m_inputRows || inputRows[1] < 0)
466 {
467 // all zeros
468 return internal::pset1<Packet>(Scalar(0));
469 }
470
471 if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
472 {
473 // no padding
474 const Index depth = patchId - patchOffsets[0] * patchDepth();
475 const Index inputIndex =
476 depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
477 return m_impl.template packet<Unaligned>(inputIndex);
478 }
479 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
480 }
481
482 // Load standard packet from a patch specified by the "within patch offset"
483 // (patchId) and the precomputed indices of the first element of the patch.
484 // This function will be called if partial packet loading is not available
485 // for the TensorEvaluator or if the packet type does not support masked
486 // load.
487 template <typename PacketT, typename TensorEvaluatorT>
488 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
489 !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
490 loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
491 {
492 const Index packetSize = internal::unpacket_traits<Packet>::size;
493 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
494 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
495
496 eigen_assert(!nonStandardPatches());
497
498 if ((patchDepth() % packetSize) == 0)
499 {
500 return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
501 }
502
503 // Offsets and input calculation here are identical to
504 // loadCoeffStandard(...), but repeated twice.
505 const Index patchOffsets[2] = {patchId / m_fastDimZero,
506 (patchId + packetSize - 1) / m_fastDimZero};
507 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
508 patchOffsets[1] / m_fastColStride};
509 const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
510
511 if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
512 {
513 // all zeros
514 return internal::pset1<Packet>(Scalar(0));
515 }
516 if (inputCols[0] == inputCols[1])
517 {
518 return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
519 patchOffsets, colOffsets, inputCols);
520 }
521 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
522 }
523
524 // Load standard packet from a patch specified by the "within patch offset"
525 // (patchId) and the precomputed indices of the first element of the patch.
526 // This function will be called if partial packet loading is available for
527 // the TensorEvaluator and if the packet type supports masked load.
528 // The only difference between this and the other case is that if the packet
529 // to load is split across two columns, then in this case instead of going to
530 // the slow (element-by-element) load, we load two packets - each containing
531 // elements from one of the columns (rest of the elements of the packets are
532 // zeroes), and then combine these two packets to generate the required
533 // packet. The idea is to enable fast load (if possible) of these 'partial'
534 // packets.
535 template <typename PacketT, typename TensorEvaluatorT>
536 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
537 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
538 loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
539 {
540 const Index packetSize = internal::unpacket_traits<PacketT>::size;
541 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
542 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
543
544 eigen_assert(!nonStandardPatches());
545
546 if ((patchDepth() % packetSize) == 0)
547 {
548 return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
549 }
550
551 // Offsets and input calculation here are identical to
552 // loadCoeffStandard(...), but repeated twice.
553 const Index patchOffsets[2] = {patchId / m_fastDimZero,
554 (patchId + packetSize - 1) / m_fastDimZero};
555 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
556 patchOffsets[1] / m_fastColStride};
557 const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
558
559 if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
560 {
561 // all zeros
562 return internal::pset1<PacketT>(Scalar(0));
563 }
564 if (inputCols[0] == inputCols[1])
565 {
566 return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
567 patchOffsets, colOffsets, inputCols);
568 }
569 if (inputCols[1] == inputCols[0] + 1)
570 {
571 return loadPacketStandardFromTwoColumns(patchId, rowIndex, colIndex, otherIndex, patchOffsets,
572 colOffsets);
573 }
574 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
575 }
576
577 EIGEN_DEVICE_FUNC
578 EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, Index colIndex,
579 Index otherIndex) const
580 {
581 const Index packetSize = internal::unpacket_traits<Packet>::size;
582 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
583 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
584
585 eigen_assert(!nonStandardPatches());
586 eigen_assert((patchDepth() % packetSize) == 0);
587 // Find the offset of the element wrt the location of the first element.
588 const Index patchOffset = patchId / m_fastDimZero;
589 eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
590
591 const Index colOffset = patchOffset / m_fastColStride;
592 const Index rowOffset = patchOffset - colOffset * m_colStride;
593 const Index inputCol = colIndex + colOffset;
594 const Index inputRow = rowIndex + rowOffset;
595 if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || inputRow >= m_inputRows)
596 {
597 // all zeros
598 return internal::pset1<Packet>(Scalar(0));
599 }
600 // no padding
601 const Index depth = patchId - patchOffset * patchDepth();
602 const Index inputIndex =
603 depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
604 return m_impl.template packet<Unaligned>(inputIndex);
605 }
606
607 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex,
608 Index colIndex,
609 Index otherIndex) const
610 {
611 const int packetSize = internal::unpacket_traits<Packet>::size;
612 EIGEN_ALIGN_MAX
613 std::remove_const_t<Scalar> values[packetSize];
614 for (int i = 0; i < packetSize; ++i)
615 {
616 values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
617 }
618 Packet rslt = internal::pload<Packet>(values);
619 return rslt;
620 }
621
622 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
623 computeBaseIndices(Index patchIndex, Index &rowIndex, Index &colIndex, Index &otherIndex) const
624 {
625 const size_t NumInputDims =
626 array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
627 otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
628 const Index patch2DIndex =
629 (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
630 otherIndex *= m_patchInputStride;
631 colIndex = patch2DIndex / m_fastOutputRows;
632 rowIndex = patch2DIndex - colIndex * m_outputRows;
633 colIndex = colIndex * m_col_strides - m_colPaddingLeft;
634 rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
635 }
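  // Worked example (hypothetical sizes): for a single 3D input
  // (otherIndex == 0), m_outputRows == 4, unit user strides and no padding,
  // patchIndex == 7 yields colIndex = 7 / 4 = 1 and rowIndex = 7 - 1 * 4 = 3,
  // i.e. the patch whose top-left corner sits at input row 3, column 1.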
636
637 Index m_patch_cols; // number of columns in the patch
638 Index m_num_patches; // number of patches to extract.
639
640 // Strides for navigating through the single patch.
641 Index m_patch_row_stride;
642 Index m_patch_col_stride;
643 internal::TensorIntDivisor<Index> m_fastPatchRowStride;
644 internal::TensorIntDivisor<Index> m_fastPatchColStride;
645
646 Index m_patch_row_inflate_strides; // the strides for row inflation in the
647 // image patch
648 Index m_patch_col_inflate_strides; // the strides for col inflation in the
649 // image patch
650 // Fast representation of inflation strides.
651 internal::TensorIntDivisor<Index> m_fastInputRowStride;
652 internal::TensorIntDivisor<Index> m_fastInputColStride;
653
654 Index m_otherStride;
655 Index m_colStride;
656 internal::TensorIntDivisor<Index> m_fastNumPatches;
657 internal::TensorIntDivisor<Index> m_fastColStride;
658
659 Index m_rowInputStride; // row stride in the input tensor
660 Index m_colInputStride; // col stride in the input tensor
661 Index m_patchInputStride; // patch stride in the input tensor
662
663 Index m_inputRows; // Number of rows in the input tensor
664 Index m_inputCols; // Number of cols in the input tensor
665
666 Index m_outputRows; // Number of convolution output rows
667 Index m_outputCols; // Number of convolution output columns
668
669 Index m_row_strides; // User specified row stride
670 Index m_col_strides; // User specified col stride
671
672 Index m_in_row_strides; // User specified input row stride
673 Index m_in_col_strides; // User specified input col stride
674
675 Index m_rowPaddingTop; // Row padding
676 Index m_colPaddingLeft; // Column padding
677
678 internal::TensorIntDivisor<Index> m_fastOutputRows;
679 internal::TensorIntDivisor<Index> m_fastDimZero;
680
681 const TensorEvaluator<ArgType, Device> m_impl;
682};
683
684template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
685 typename Scalar, typename Index, typename nocontract_t, typename contract_t, int Side,
686 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
687class TensorContractionSubMapper<
688 Scalar, Index, Side,
689 TensorEvaluator<
690 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
691 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
692{
693public:
694 typedef typename packet_traits<Scalar>::type Packet;
695 typedef typename packet_traits<Scalar>::half HalfPacket;
696
697 typedef TensorContractionInputMapper<
698 Scalar, Index, Side,
699 TensorEvaluator<
700 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
701 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
702 ParentMapper;
703
704 typedef TensorContractionSubMapper<
705 Scalar, Index, Side,
706 TensorEvaluator<
707 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
708 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
709 Self;
710
711 typedef Self LinearMapper;
712
713 typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
714
715 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const ParentMapper &base_mapper,
716 Index vert_offset,
717 Index horiz_offset)
718 : m_depth_offset(vert_offset), m_col_offset(horiz_offset), m_base_mapper(base_mapper)
719 {
720 m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
721 }
722 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const Self &base_mapper,
723 Index vert_offset,
724 Index horiz_offset)
725 : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
726 m_col_offset(horiz_offset + base_mapper.m_col_offset),
727 m_base_mapper(base_mapper.m_base_mapper)
728 {
729 m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
730 }
731 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const
732 {
733 return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
734 }
735 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const
736 {
737 return m_base_mapper(i + m_depth_offset, j + m_col_offset);
738 }
739
740 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const
741 {
742 return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
743 }
744 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const
745 {
746 return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, j + m_col_offset);
747 }
748 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i) const
749 {
750 return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, m_colIndex,
751 m_otherIndex);
752 }
753
754 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const
755 {
756 return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
757 }
758 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i) const
759 {
760 typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
761 return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
762 i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
763 }
764 template <typename Packet> EIGEN_DEVICE_FUNC bool aligned(Index) const { return false; }
765
766 EIGEN_DEVICE_FUNC
767 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { return m_base_mapper.nonStandardPatches(); }
768
769 // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
770 // index respectively that fits into the peeled_k elements starting at
771 // m_depth_offset.
772
773 EIGEN_DEVICE_FUNC
774 EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const
775 {
776 const Index max_col =
777 (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / fastPatchColStride();
778 return std::min<Index>(1 + max_col, patchCols());
779 }
780
781 EIGEN_DEVICE_FUNC
782 EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, const Index col) const
783 {
784 const Index max_row =
785 (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - col * patchColStride()) /
786 fastPatchRowStride();
787 return std::min<Index>(1 + max_row, patchRows());
788 }
789
790 EIGEN_DEVICE_FUNC
791 EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, Index row) const
792 {
793 const Index max_depth = m_depth_offset + peeled_k - //
794 col * patchColStride() - //
795 row * patchRowStride();
796 return std::min<Index>(max_depth, patchDepth());
797 }
798
799 // MaxDepth uses only the remaining number of elements in the peeled_k.
800 EIGEN_DEVICE_FUNC
801 EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, const Index start_depth) const
802 {
803 return std::min<Index>(start_depth + num_elements, patchDepth());
804 }
805
806 // Every register matters in this code, so sometimes to prevent register
807 // spilling, instead of the variable that you would expect to see, we use
808 // another one that is guaranteed to have the same value. E.g. patch depth is
809 // always the same as input depth, and it's also the same as input row stride.
810 // A bunch of other parameters have similar relations.
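  // For instance, patchDepth() below returns m_base_mapper.m_rowInputStride
  // rather than a dedicated "patch depth" member: the input mapper's
  // constructor sets m_rowInputStride = patch_depth, so the two values always
  // coincide and one register can serve both roles.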
811
812 typedef internal::TensorIntDivisor<Index> IndexDivisor;
813
814 EIGEN_DEVICE_FUNC
815 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_base_mapper.m_rowInputStride; }
816 EIGEN_DEVICE_FUNC
817 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_base_mapper.m_colStride; }
818 EIGEN_DEVICE_FUNC
819 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_base_mapper.m_patch_cols; }
820
821 EIGEN_DEVICE_FUNC
822 EIGEN_ALWAYS_INLINE Index patchRowStride() const
823 {
824 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
825 "Patch depth must be equal to patch row stride.");
826 return patchDepth();
827 }
828 EIGEN_DEVICE_FUNC
829 EIGEN_ALWAYS_INLINE Index patchColStride() const { return m_base_mapper.m_patch_col_stride; }
830
831 EIGEN_DEVICE_FUNC
832 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const
833 {
834 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
835 "Patch depth must be equal to patch row stride.");
836 return m_base_mapper.m_fastDimZero; // patch_depth
837 }
838 EIGEN_DEVICE_FUNC
839 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const
840 {
841 return m_base_mapper.m_fastPatchColStride;
842 }
843
844 EIGEN_DEVICE_FUNC
845 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const
846 {
847 const Index inputIndex = depth + baseIndex;
848 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
849 }
850 EIGEN_DEVICE_FUNC
851 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, const Index baseIndex) const
852 {
853 const Index inputIndex = depth + baseIndex;
854 return m_base_mapper.m_impl.coeff(inputIndex);
855 }
856 template <typename PacketT = Packet>
857 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
858 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
859 partialPacketNoPadding(const Index depth, const Index baseIndex, Index num_coeffs) const
860 {
861 const Index inputIndex = depth + baseIndex;
862 return m_base_mapper.m_impl.template partialPacket<PacketT>(inputIndex,
863 mask<PacketT>(0, num_coeffs));
864 }
865 EIGEN_DEVICE_FUNC
866 EIGEN_ALWAYS_INLINE bool hasPadding() const
867 {
868 // TODO(ezhulenev): It does seem that for an inflated filter it's still
869 // possible to guarantee "no padding or skipping" for non-standard packing.
870 if (nonStandardPatches())
871 return true;
872
873 // Non zero padding before.
874 if (m_base_mapper.m_rowPaddingTop > 0)
875 return true;
876 if (m_base_mapper.m_colPaddingLeft > 0)
877 return true;
878
879 // Non zero padding after in rows.
880 const Index last_row = (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
881 if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows)
882 return true;
883
884 // Non zero padding after in cols.
885 const Index last_col = (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
886 if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols)
887 return true;
888
889 return false;
890 }
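  // Illustrative note: a 'VALID' convolution with unit strides and no
  // dilation/inflation has zero top/left padding and
  // last_row + patchRows() - 1 == m_inputRows - 1, so hasPadding() returns
  // false; a 'SAME' convolution with a kernel larger than 1x1 pads on at
  // least one side and returns true.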
891 EIGEN_DEVICE_FUNC
892 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const
893 {
894 const Index r = m_rowIndex + row;
895 return r < 0 || r >= m_base_mapper.m_inputRows;
896 }
897 EIGEN_DEVICE_FUNC
898 EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, const Index last_row) const
899 {
900 return m_rowIndex + first_row < 0 || m_rowIndex + last_row >= m_base_mapper.m_inputRows;
901 }
902 EIGEN_DEVICE_FUNC
903 EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row, Index *orig_row) const
904 {
905 eigen_assert(nonStandardPatches());
906
907 const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
908 *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
909 ? input_row
910 : ((input_row >= 0) ? (input_row / m_base_mapper.m_fastInputRowStride) : 0);
911
912 return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
913 (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
914 }
915 EIGEN_DEVICE_FUNC
916 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const
917 {
918 const Index c = m_colIndex + col;
919 return c < 0 || c >= m_base_mapper.m_inputCols;
920 }
921 EIGEN_DEVICE_FUNC
922 EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col, Index *orig_col) const
923 {
924 eigen_assert(nonStandardPatches());
925
926 const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
927 *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
928 ? input_col
929 : ((input_col >= 0) ? (input_col / m_base_mapper.m_fastInputColStride) : 0);
930
931 return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
932 (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
933 }
934 EIGEN_DEVICE_FUNC
935 EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const
936 {
937 const Index r = m_rowIndex + row;
938 const Index c = m_colIndex + col;
939 return r * m_base_mapper.m_rowInputStride + c * m_base_mapper.m_colInputStride + m_otherIndex;
940 }
941 // Compute a base index when original input row and column were precomputed
942 // using padOrSkipRow and padOrSkipCol. Used only for non standard patches.
943 EIGEN_DEVICE_FUNC
944 EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row, const Index orig_col) const
945 {
946 return orig_row * m_base_mapper.m_rowInputStride + orig_col * m_base_mapper.m_colInputStride +
947 m_otherIndex;
948 }
949
950 EIGEN_DEVICE_FUNC
951 EIGEN_ALWAYS_INLINE Index rowStride() const { return m_base_mapper.m_row_strides; }
952 EIGEN_DEVICE_FUNC
953 EIGEN_ALWAYS_INLINE Index colStride() const { return m_base_mapper.m_col_strides; }
954
955 EIGEN_DEVICE_FUNC
956 EIGEN_ALWAYS_INLINE Index rowOffset() const
957 {
958 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
959 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
960 return patchOffset - colOffset * m_base_mapper.m_colStride;
961 }
962
963 EIGEN_DEVICE_FUNC
964 EIGEN_ALWAYS_INLINE Index colOffset() const
965 {
966 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
967 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
968 return colOffset;
969 }
970
971 EIGEN_DEVICE_FUNC
972 EIGEN_ALWAYS_INLINE Index depthOffset() const { return m_depth_offset % patchDepth(); }
973
974 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const
975 {
976 return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
977 }
978
979private:
980 Index m_depth_offset; // First row in the input matrix
981 Index m_col_offset; // First col in the input matrix
982
983 // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
984 // indices for the first element in a patch specified by col_offset
985 // (see computeBaseIndices(...) for details).
986 Index m_rowIndex;
987 Index m_colIndex;
988 Index m_otherIndex;
989
990 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
991 // performs better in benchmarks.
992};
993
994// Arrange a block of the right input matrix (in our case it's always a "virtual
995// matrix" constructed from extracted image patches) in contiguous memory.
996//
997// Given column major input (A0 beside A1 in memory):
998// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
999// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
1000// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
1001// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
1002// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
1003// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
1004// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
1005// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
1006// A8 ...
1007// ...
1008//
1009// *) A, B, C, ... - patches extracted from the original input.
1010// *) A0, A1, A2 ... - values from the same patch at different offsets.
1011//
1012// The traversal (packed rhs memory) order (B0 besides A0 in memory):
1013// A0 B0 C0 D0 A1 B1 C1 D1 ...
1014// E0 F0 G0 H0 E1 F1 G1 H1 ...
1015// ...
1016// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1017//
1018// This traversal order must be the same as in default gemm_pack_rhs defined in
1019// GeneralBlockPanelKernel.h.
1020//
1021// *) nr - number of registers along the 'n' dimension.
1022// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1023// Multiplication" paper.
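//
// Illustrative note for the loops below (assuming float data, packet_size == 4
// and nr == 4): each inner iteration reads one packet from each of 4
// neighbouring "virtual matrix" columns, transposes the resulting 4x4 block
// with ptranspose(), and stores it contiguously into `block`, which produces
// exactly the A0 B0 C0 D0 A1 B1 C1 D1 ... layout sketched above.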
1024template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1025 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1026 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1027 int nr>
1028struct gemm_pack_rhs<
1029 Scalar, Index,
1030 TensorContractionSubMapper<
1031 Scalar, Index, Rhs,
1032 TensorEvaluator<
1033 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1034 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1035 nr, ColMajor, false, false>
1036{
1037 typedef TensorContractionSubMapper<
1038 Scalar, Index, Rhs,
1039 TensorEvaluator<
1040 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1041 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
1042 SubMapper;
1043 typedef SubMapper DataMapper;
1044 typedef typename packet_traits<Scalar>::type Packet;
1045
1046 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1047
1048 EIGEN_DEVICE_FUNC
1049 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1050 Index stride = 0, Index offset = 0) const
1051 {
1052 eigen_assert(stride == 0);
1053 eigen_assert(offset == 0);
1054 (void)stride;
1055 (void)offset;
1056
1057 const Index packet_cols4 = (cols / 4) * 4;
1058 const Index peeled_k = (depth / packet_size) * packet_size;
1059 const bool non_standard_patches = rhs.nonStandardPatches();
1060
1061 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1062 {
1063 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1064 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1065 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1066 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1067
1068 Index k = 0;
1069 if ((packet_size % 4) == 0 && !non_standard_patches)
1070 {
1071 // FAST PATH:
1072 // Iterate over patch columns and rows, if we know that a single
1073 // packet does not span across multiple rows or columns.
1074 if ((rhs.patchDepth() % packet_size) == 0)
1075 {
1076 const Index start_col = rhs.colOffset();
1077 const Index max_col = rhs.maxCol(peeled_k);
1078
1079 for (Index c = start_col; c < max_col; ++c)
1080 {
1081 eigen_assert(k <= peeled_k);
1082
1083 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1084 const Index max_row = rhs.maxRow(peeled_k, c);
1085
1086 const bool pad_col0 = dm0.padCol(c);
1087 const bool pad_col1 = dm1.padCol(c);
1088 const bool pad_col2 = dm2.padCol(c);
1089 const bool pad_col3 = dm3.padCol(c);
1090
1091 // Check if we can squeeze reads along the `row` and `depth`
1092 // dimensions (two innermost dimensions).
1093 if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1094 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1095 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1096 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1097 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1098 {
1099 // Compute how many elements we can squeeze read.
1100 const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1101
1102 // Upper bound for the number of elements in the depth dimension
1103 // that we can squeeze read.
1104 const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1105
1106 // Do not overshoot beyond the block size.
1107 const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1108 eigen_assert((max_depth - start_depth) % packet_size == 0);
1109
1110 const Index idx0 = dm0.baseIndex(start_row, c);
1111 const Index idx1 = dm1.baseIndex(start_row, c);
1112 const Index idx2 = dm2.baseIndex(start_row, c);
1113 const Index idx3 = dm3.baseIndex(start_row, c);
1114
1115 for (Index d = start_depth; d < max_depth; d += packet_size)
1116 {
1117 eigen_assert(k < peeled_k);
1118 PacketBlock<Packet, 4> kernel;
1119 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1120 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1121 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1122 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1123 ptranspose(kernel);
1124 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1125 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1126 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1127 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1128 block += 4 * packet_size;
1129 k += packet_size;
1130 }
1131
1132 // Go to the next column.
1133 continue;
1134 }
1135
1136 // If we can't squeeze reads, process rows one by one.
1137 for (Index r = start_row; r < max_row; ++r)
1138 {
1139 eigen_assert(k <= peeled_k);
1140
1141 const bool pad0 = pad_col0 || dm0.padRow(r);
1142 const bool pad1 = pad_col1 || dm1.padRow(r);
1143 const bool pad2 = pad_col2 || dm2.padRow(r);
1144 const bool pad3 = pad_col3 || dm3.padRow(r);
1145
1146 const Index idx0 = dm0.baseIndex(r, c);
1147 const Index idx1 = dm1.baseIndex(r, c);
1148 const Index idx2 = dm2.baseIndex(r, c);
1149 const Index idx3 = dm3.baseIndex(r, c);
1150
1151 const Index start_depth =
1152 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1153 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1154 eigen_assert((max_depth - start_depth) % packet_size == 0);
1155
1156 for (Index d = start_depth; d < max_depth; d += packet_size)
1157 {
1158 eigen_assert(k < peeled_k);
1159 PacketBlock<Packet, 4> kernel;
1160 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1161 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1162 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1163 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1164 ptranspose(kernel);
1165 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1166 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1167 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1168 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1169 block += 4 * packet_size;
1170 k += packet_size;
1171 }
1172 }
1173 }
1174
1175 // The loop above should fill peeled_k elements.
1176 eigen_assert(peeled_k == k);
1177 }
1178 else
1179 {
1180 for (; k < peeled_k; k += packet_size)
1181 {
1182 PacketBlock<Packet, 4> kernel;
1183 kernel.packet[0] = dm0.loadPacketStandard(k);
1184 kernel.packet[1] = dm1.loadPacketStandard(k);
1185 kernel.packet[2] = dm2.loadPacketStandard(k);
1186 kernel.packet[3] = dm3.loadPacketStandard(k);
1187 ptranspose(kernel);
1188 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1189 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1190 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1191 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1192 block += 4 * packet_size;
1193 }
1194 }
1195 }
1196
1197 // Copy the remaining coefficients of the column block after the peeled_k.
1198 if (!rhs.nonStandardPatches())
1199 {
1200 for (; k < depth; k++)
1201 {
1202 block[0] = dm0.loadCoeffStandard(k);
1203 block[1] = dm1.loadCoeffStandard(k);
1204 block[2] = dm2.loadCoeffStandard(k);
1205 block[3] = dm3.loadCoeffStandard(k);
1206 block += 4;
1207 }
1208 }
1209 else
1210 {
1211 for (; k < depth; k++)
1212 {
1213 block[0] = dm0(k);
1214 block[1] = dm1(k);
1215 block[2] = dm2(k);
1216 block[3] = dm3(k);
1217 block += 4;
1218 }
1219 }
1220 }
1221
1222 // Copy the remaining columns one at a time (nr==1).
1223 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1224 {
1225 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1226 for (Index k = 0; k < depth; k++)
1227 {
1228 *block = dm0(k);
1229 block += 1;
1230 }
1231 }
1232 }
1233};
1234
1235// Template specialization for packet_size = 2. We must special-case packet
1236// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1237template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1238 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1239 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1240struct gemm_pack_rhs<
1241 Scalar, Index,
1242 TensorContractionSubMapper<
1243 Scalar, Index, Rhs,
1244 TensorEvaluator<
1245 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1246 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1247 nr, ColMajor, false, false>
1248{
1249 typedef TensorContractionSubMapper<
1250 Scalar, Index, Rhs,
1251 TensorEvaluator<
1252 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1253 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>
1254 SubMapper;
1255 typedef SubMapper DataMapper;
1256 typedef typename packet_traits<Scalar>::type Packet;
1257
1258 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1259
1260 EIGEN_DEVICE_FUNC
1261 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1262 Index stride = 0, Index offset = 0) const
1263 {
1264 eigen_assert(stride == 0);
1265 eigen_assert(offset == 0);
1266
1267 (void)stride;
1268 (void)offset;
1269
1270 const int packet_size = 2;
1271 const Index packet_cols4 = (cols / 4) * 4;
1272 const Index peeled_k = (depth / packet_size) * packet_size;
1273 const bool non_standard_patches = rhs.nonStandardPatches();
1274
1275 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1276 {
1277 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1278 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1279 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1280 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1281
1282 Index k = 0;
1283 if (!non_standard_patches)
1284 {
1285 // FAST PATH:
1286 // Iterate over patch columns and rows if we know that a single
1287 // packet does not span across multiple rows or columns.
1288 if ((rhs.patchDepth() % packet_size) == 0)
1289 {
1290 const Index start_col = rhs.colOffset();
1291 const Index max_col = rhs.maxCol(peeled_k);
1292
1293 for (Index c = start_col; c < max_col; ++c)
1294 {
1295 eigen_assert(k <= peeled_k);
1296
1297 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1298 const Index max_row = rhs.maxRow(peeled_k, c);
1299
1300 const bool pad_col0 = dm0.padCol(c);
1301 const bool pad_col1 = dm1.padCol(c);
1302 const bool pad_col2 = dm2.padCol(c);
1303 const bool pad_col3 = dm3.padCol(c);
1304
1305 // We can squeeze reads along the `row` and `depth` dimensions if
1306 // the row stride is `1`, which means that `row` and `depth`
1307 // dimensions are contiguous (two innermost dimensions).
1308 if (rhs.rowStride() == 1 && //
1309 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1310 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1311 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1312 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1313 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1314 {
1315 // Compute how many elements we can squeeze read.
1316 const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1317
1318 // Upper bound for the number of elements in the depth dimension
1319 // that we can squeeze read.
1320 const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1321
1322 // Do not overshoot beyond the block size.
1323 const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1324 eigen_assert((max_depth - start_depth) % packet_size == 0);
1325
1326 const Index idx0 = dm0.baseIndex(start_row, c);
1327 const Index idx1 = dm1.baseIndex(start_row, c);
1328 const Index idx2 = dm2.baseIndex(start_row, c);
1329 const Index idx3 = dm3.baseIndex(start_row, c);
1330
1331 for (Index d = start_depth; d < max_depth; d += packet_size)
1332 {
1333 PacketBlock<Packet, 2> kernel0;
1334 PacketBlock<Packet, 2> kernel1;
1335 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1336 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1337 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1338 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1339 ptranspose(kernel0);
1340 ptranspose(kernel1);
1341 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1342 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1343 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1344 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1345 block += 4 * packet_size;
1346 k += packet_size;
1347 }
1348
1349 // Go to the next column.
1350 continue;
1351 }
1352
1353 // If we can't squeeze reads, process rows one by one.
1354 for (Index r = start_row; r < max_row; ++r)
1355 {
1356 eigen_assert(k <= peeled_k);
1357
1358 const bool pad0 = pad_col0 || dm0.padRow(r);
1359 const bool pad1 = pad_col1 || dm1.padRow(r);
1360 const bool pad2 = pad_col2 || dm2.padRow(r);
1361 const bool pad3 = pad_col3 || dm3.padRow(r);
1362
1363 const Index idx0 = dm0.baseIndex(r, c);
1364 const Index idx1 = dm1.baseIndex(r, c);
1365 const Index idx2 = dm2.baseIndex(r, c);
1366 const Index idx3 = dm3.baseIndex(r, c);
1367
1368 const Index start_depth =
1369 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1370 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1371 eigen_assert((max_depth - start_depth) % packet_size == 0);
1372
1373 for (Index d = start_depth; d < max_depth; d += packet_size)
1374 {
1375 eigen_assert(k < peeled_k);
1376 PacketBlock<Packet, 2> kernel0;
1377 PacketBlock<Packet, 2> kernel1;
1378 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1379 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1380 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1381 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1382 ptranspose(kernel0);
1383 ptranspose(kernel1);
1384 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1385 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1386 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1387 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1388 block += 4 * packet_size;
1389 k += packet_size;
1390 }
1391 }
1392 }
1393
1394 // The loop above should fill peeled_k elements.
1395 eigen_assert(peeled_k == k);
1396 }
1397 else
1398 {
1399 // Packet can span multiple rows or columns, so we have to go
1400        // through the slower "standard" path.
1401 for (; k < peeled_k; k += packet_size)
1402 {
1403 PacketBlock<Packet, 2> kernel0;
1404 PacketBlock<Packet, 2> kernel1;
1405 kernel0.packet[0] = dm0.loadPacketStandard(k);
1406 kernel0.packet[1] = dm1.loadPacketStandard(k);
1407 kernel1.packet[0] = dm2.loadPacketStandard(k);
1408 kernel1.packet[1] = dm3.loadPacketStandard(k);
1409 ptranspose(kernel0);
1410 ptranspose(kernel1);
1411 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1412 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1413 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1414 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1415 block += 4 * packet_size;
1416 }
1417 }
1418 }
1419
1420 // Copy the remaining coefficients of the column block after the peeled_k.
1421 if (!non_standard_patches)
1422 {
1423 for (; k < depth; k++)
1424 {
1425 block[0] = dm0.loadCoeffStandard(k);
1426 block[1] = dm1.loadCoeffStandard(k);
1427 block[2] = dm2.loadCoeffStandard(k);
1428 block[3] = dm3.loadCoeffStandard(k);
1429 block += 4;
1430 }
1431 }
1432 else
1433 {
1434 for (; k < depth; k++)
1435 {
1436 block[0] = dm0(k);
1437 block[1] = dm1(k);
1438 block[2] = dm2(k);
1439 block[3] = dm3(k);
1440 block += 4;
1441 }
1442 }
1443 }
1444
1445 // Copy the remaining columns one at a time (nr==1).
1446 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1447 {
1448 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1449 for (Index k = 0; k < depth; k++)
1450 {
1451 *block = dm0(k);
1452 block += 1;
1453 }
1454 }
1455 }
1456};
1457
1458// Special case for non-vectorized types such as float16.
1459template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1460 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1461 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1462struct gemm_pack_rhs<
1463 Scalar, Index,
1464 TensorContractionSubMapper<
1465 Scalar, Index, Rhs,
1466 TensorEvaluator<
1467 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1468 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1469 nr, ColMajor, false, false>
1470{
1471 typedef TensorContractionSubMapper<
1472 Scalar, Index, Rhs,
1473 TensorEvaluator<
1474 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>, Device>,
1475 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
1476    SubMapper;
1477  typedef SubMapper DataMapper;
1478
1479 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1480
1481 EIGEN_DEVICE_FUNC
1482 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1483 Index stride = 0, Index offset = 0) const
1484 {
1485 eigen_assert(stride == 0);
1486 eigen_assert(offset == 0);
1487
1488 (void)offset;
1489 (void)stride;
1490
1491 const Index packet_cols4 = (cols / 4) * 4;
1492
1493 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1494 {
1495 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1496 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1497 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1498 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1499
1500 if (!rhs.nonStandardPatches())
1501 {
1502 for (Index k = 0; k < depth; k++)
1503 {
1504 block[0] = dm0.loadCoeffStandard(k);
1505 block[1] = dm1.loadCoeffStandard(k);
1506 block[2] = dm2.loadCoeffStandard(k);
1507 block[3] = dm3.loadCoeffStandard(k);
1508 block += 4;
1509 }
1510 }
1511 else
1512 {
1513 for (Index k = 0; k < depth; k++)
1514 {
1515 block[0] = dm0(k);
1516 block[1] = dm1(k);
1517 block[2] = dm2(k);
1518 block[3] = dm3(k);
1519 block += 4;
1520 }
1521 }
1522 }
1523
1524 // Copy the remaining columns one at a time (nr==1).
1525 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1526 {
1527 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1528 for (Index k = 0; k < depth; k++)
1529 {
1530 *block = dm0(k);
1531 block += 1;
1532 }
1533 }
1534 }
1535};
1536} // end namespace internal
1537
1569template <typename Input, typename Kernel, typename OutputKernel = const NoOpOutputKernel>
1570EIGEN_ALWAYS_INLINE static const std::conditional_t<
1571 internal::traits<Input>::Layout == ColMajor,
1572 TensorReshapingOp<
1573 const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>,
1574 const TensorContractionOp<
1575 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1576 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1577 const Kernel>,
1578 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1579 const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1580 const OutputKernel>>,
1581 TensorReshapingOp<
1582 const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>,
1583 const TensorContractionOp<
1584 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1585 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1586 const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1587 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1588 const Kernel>,
1589 const OutputKernel>>>
1590SpatialConvolution(const Input &input, const Kernel &kernel, const Index row_stride = 1,
1591 const Index col_stride = 1, const PaddingType padding_type = PADDING_SAME,
1592 const Index row_in_stride = 1, const Index col_in_stride = 1,
1593 const OutputKernel &output_kernel = OutputKernel(), Index padding_top = 0,
1594 Index padding_bottom = 0, Index padding_left = 0, Index padding_right = 0)
1595{
1596 typedef typename internal::traits<Input>::Index TensorIndex;
1597 typedef typename internal::traits<Input>::Scalar InputScalar;
1598 TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
1599 internal::traits<Input>::Layout, TensorIndex>>
1600 in(input);
1601 TensorRef<
1602 Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions,
1603 internal::traits<Kernel>::Layout, TensorIndex>>
1604 kern(kernel);
1605
1606 EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1607 YOU_MADE_A_PROGRAMMING_MISTAKE)
1608 const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1609
1610 const int NumDims = internal::traits<Input>::NumDimensions;
1611
1612 // Number of filters to apply. This is the same as the output depth of the
1613  // result.
1614 const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
1615 // Number of channels. This is the same as the input depth.
1616 const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
1617 const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
1618 const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
1619
1620 const Index kernelRowsEff = kernelRows + (kernelRows - 1) * (row_in_stride - 1);
1621 const Index kernelColsEff = kernelCols + (kernelCols - 1) * (col_in_stride - 1);
1622
1623 array<IndexPair<TensorIndex>, 1> contract_dims;
1624 contract_dims[0] = IndexPair<TensorIndex>(1, 0);
1625
1626 const TensorIndex InputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1627 const TensorIndex InputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1628 const bool padding_explicit = (padding_top || padding_bottom || padding_left || padding_right);
1629
1630 TensorIndex out_height;
1631 TensorIndex out_width;
1632 switch (padding_type)
1633 {
1634 case PADDING_VALID:
1635 {
1636 const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
1637 const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
1638 out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
1639 out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
1640 break;
1641 }
1642 case PADDING_SAME:
1643 {
1644 eigen_assert(!padding_explicit);
1645 out_height = divup(InputRows, row_stride);
1646 out_width = divup(InputCols, col_stride);
1647 break;
1648 }
1649 default:
1650 {
1651 // Initialize unused variables to avoid a compiler warning
1652 out_height = 0;
1653 out_width = 0;
1654 eigen_assert(false && "unexpected padding");
1655 }
1656 }
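  // Worked example with illustrative (hypothetical) sizes: for
  // InputRows == InputCols == 32, a 3x3 kernel and
  // row_in_stride == col_in_stride == 1 (so kernelRowsEff == kernelColsEff == 3):
  //   PADDING_VALID, stride 1, no explicit padding: divup(32 - 3 + 1, 1) == 30
  //   PADDING_VALID, stride 2, no explicit padding: divup(30, 2)         == 15
  //   PADDING_SAME,  stride 2:                      divup(32, 2)         == 16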
1657
1658 // Molds the output of the patch extraction code into a 2d tensor:
1659 // - the first dimension (dims[0]): the patch values to be multiplied with the
1660 // kernels
1661 // - the second dimension (dims[1]): everything else
1662 DSizes<TensorIndex, 2> pre_contract_dims;
1663 if (isColMajor)
1664 {
1665 pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
1666 pre_contract_dims[1] = out_height * out_width;
1667 for (int i = 3; i < NumDims; ++i)
1668 {
1669 pre_contract_dims[1] *= in.dimension(i);
1670 }
1671 }
1672 else
1673 {
1674 pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
1675 pre_contract_dims[0] = out_height * out_width;
1676 for (int i = 0; i < NumDims - 3; ++i)
1677 {
1678 pre_contract_dims[0] *= in.dimension(i);
1679 }
1680 }
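  // Continuing with illustrative (hypothetical) shapes in ColMajor:
  // kernelChannels == 3, a 3x3 kernel, out_height == out_width == 32 and one
  // batch dimension of size 16 give
  // pre_contract_dims == {3 * 3 * 3, 32 * 32 * 16} == {27, 16384}:
  // 27 patch coefficients per extracted patch, 16384 extracted patches in total.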
1681
1682  // Molds the output of the contraction into the shape expected by the user
1683 // (assuming this is ColMajor):
1684 // - 1st dim: kernel filters
1685 // - 2nd dim: output height
1686 // - 3rd dim: output width
1687 // - 4th dim and beyond: everything else including batch size
1688 DSizes<TensorIndex, NumDims> post_contract_dims;
1689 if (isColMajor)
1690 {
1691 post_contract_dims[0] = kernelFilters;
1692 post_contract_dims[1] = out_height;
1693 post_contract_dims[2] = out_width;
1694 for (int i = 3; i < NumDims; ++i)
1695 {
1696 post_contract_dims[i] = in.dimension(i);
1697 }
1698 }
1699 else
1700 {
1701 post_contract_dims[NumDims - 1] = kernelFilters;
1702 post_contract_dims[NumDims - 2] = out_height;
1703 post_contract_dims[NumDims - 3] = out_width;
1704 for (int i = 0; i < NumDims - 3; ++i)
1705 {
1706 post_contract_dims[i] = in.dimension(i);
1707 }
1708 }
1709
1710 DSizes<TensorIndex, 2> kernel_dims;
1711 if (isColMajor)
1712 {
1713 kernel_dims[0] = kernelFilters;
1714 kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
1715 }
1716 else
1717 {
1718 kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
1719 kernel_dims[1] = kernelFilters;
1720 }
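  // With the same illustrative shapes and kernelFilters == 8 (ColMajor),
  // kernel_dims == {8, 27}. contract_dims == {IndexPair(1, 0)} contracts the
  // 27-long dimension of the reshaped kernel against the 27-long patch
  // dimension, producing an {8, 16384} matrix that post_contract_dims
  // reshapes to {8, 32, 32, 16} (filters, out_height, out_width, batch).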
1721 if (padding_explicit)
1722 {
1723 return choose(
1724 Cond<internal::traits<Input>::Layout == ColMajor>(),
1725 kernel.reshape(kernel_dims)
1726 .contract(input
1727 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1728 row_in_stride, col_in_stride,
1729 /*row_inflate_stride=*/1,
1730 /*col_inflate_stride=*/1, padding_top, padding_bottom,
1731 padding_left, padding_right,
1732 /*padding_value=*/static_cast<InputScalar>(0))
1733 .reshape(pre_contract_dims),
1734 contract_dims, output_kernel)
1735 .reshape(post_contract_dims),
1736 input
1737 .extract_image_patches(
1738 kernelRows, kernelCols, row_stride, col_stride, row_in_stride, col_in_stride,
1739 /*row_inflate_stride=*/1,
1740 /*col_inflate_stride=*/1, padding_top, padding_bottom, padding_left, padding_right,
1741 /*padding_value=*/static_cast<InputScalar>(0))
1742 .reshape(pre_contract_dims)
1743 .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1744 .reshape(post_contract_dims));
1745 }
1746 else
1747 {
1748 return choose(
1749 Cond<internal::traits<Input>::Layout == ColMajor>(),
1750 kernel.reshape(kernel_dims)
1751 .contract(input
1752 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1753 row_in_stride, col_in_stride, padding_type)
1754 .reshape(pre_contract_dims),
1755 contract_dims, output_kernel)
1756 .reshape(post_contract_dims),
1757 input
1758 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
1759 col_in_stride, padding_type)
1760 .reshape(pre_contract_dims)
1761 .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1762 .reshape(post_contract_dims));
1763 }
1764}
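// A minimal usage sketch (illustrative only; the tensor sizes and variable
// names below are hypothetical). In ColMajor layout the input is
// (channels, rows, cols[, other dims...]) and the kernel is
// (filters, channels, kernel_rows, kernel_cols). Explicit padding is meant to
// be combined with PADDING_VALID; PADDING_SAME asserts that no explicit
// padding was passed.
//
//   Eigen::Tensor<float, 3> input(3, 32, 32);   // channels, rows, cols
//   Eigen::Tensor<float, 4> kernel(8, 3, 3, 3); // filters, channels, kR, kC
//   input.setRandom();
//   kernel.setRandom();
//   Eigen::Tensor<float, 3> output = Eigen::SpatialConvolution(input, kernel);
//   // Defaults: stride 1, PADDING_SAME -> output dimensions are (8, 32, 32).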
1765
1766} // end namespace Eigen
1767
1768#endif // __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__