84 const Shape &filter_shape,
const uint8_t *filter_data,
85 [[maybe_unused]]
const Shape &bias_shape,
const int32_t *bias_data,
106 const uint8_t *gemm_input_data =
nullptr;
107 const Shape *gemm_input_shape =
nullptr;
108 const int filter_width = filter_shape.
Dims(2);
109 const int filter_height = filter_shape.
Dims(1);
110 const bool need_dilated_im2col = dilation_width_factor != 1 || dilation_height_factor != 1;
111 const bool need_im2col =
112 stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1;
113 if (need_dilated_im2col)
116 const int input_zero_point = -input_offset;
117 assert(input_zero_point >= 0);
118 assert(input_zero_point <= 255);
121 gemm_input_data = im2col_data;
122 gemm_input_shape = &im2col_shape;
124 else if (need_im2col)
127 const int input_zero_point = -input_offset;
128 assert(input_zero_point >= 0);
129 assert(input_zero_point <= 255);
130 Im2col(params, filter_height, filter_width, input_zero_point, input_shape, input_data,
131 im2col_shape, im2col_data);
132 gemm_input_data = im2col_data;
133 gemm_input_shape = &im2col_shape;
137 gemm_input_data = input_data;
138 gemm_input_shape = &input_shape;
141 const int gemm_input_rows = gemm_input_shape->
Dims(3);
146 const int gemm_input_cols =
147 gemm_input_shape->
Dims(0) * gemm_input_shape->
Dims(1) * gemm_input_shape->
Dims(2);
148 const int filter_rows = filter_shape.
Dims(0);
151 const int filter_cols = filter_shape.
Dims(1) * filter_shape.
Dims(2) * filter_shape.
Dims(3);
156 assert(output_rows == filter_rows);
157 assert(output_cols == gemm_input_cols);
158 assert(filter_cols == gemm_input_rows);
159 assert(bias_shape.FlatSize() == output_rows);
160 gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::RowMajor> filter_matrix(
161 filter_data, filter_rows, filter_cols);
162 gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::ColMajor> input_matrix(
163 gemm_input_data, gemm_input_rows, gemm_input_cols);
164 gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::ColMajor> output_matrix(output_data, output_rows,
166 const auto &output_pipeline =
168 output_shift, output_activation_min, output_activation_max);
171 gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
172 gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, input_offset,
178namespace multithreaded
182template <
class T>
class EigenTensorConvFunctor
185 Eigen::PaddingType RuntimePadding2EigenPadding(
PaddingType padding)
190 return Eigen::PADDING_VALID;
192 return Eigen::PADDING_SAME;
195 return Eigen::PADDING_VALID;
197 return Eigen::PADDING_SAME;
202 void operator()(
const Eigen::ThreadPoolDevice &device,
const T *input_data,
int input_batches,
203 int input_height,
int input_width,
int input_depth,
const T *filter_data,
204 int filter_height,
int filter_width,
int filter_count,
int stride_rows,
206 T *output_data,
int output_height,
int output_width)
208 const bool is_1x1_kernel =
209 (filter_height == 1 && filter_width == 1 && stride_rows == 1 && stride_cols == 1);
210 const bool is_same_height_width =
211 (filter_height == input_height && filter_width == input_width && pad_width == 0 &&
213 if (is_1x1_kernel || is_same_height_width)
224 const int conv_width = output_height * output_width;
225 int io_col = input_batches;
226 int filter_col = input_depth * filter_width * filter_height;
229 io_col *= conv_width;
231 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
232 dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
247 output.device(device) = Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows,
248 RuntimePadding2EigenPadding(padding));
255 const Shape &filter_shape,
const float *filter_data,
const Shape &bias_shape,
272 const int input_depth =
MatchingDim(input_shape, 3, filter_shape, 3);
274 const int input_height = input_shape.
Dims(1);
275 const int input_width = input_shape.
Dims(2);
276 const int filter_height = filter_shape.
Dims(1);
277 const int filter_width = filter_shape.
Dims(2);
281 EigenTensorConvFunctor<float> conv_functor;
282 conv_functor(device, input_data, batches, input_height, input_width, input_depth, filter_data,
283 filter_height, filter_width, output_depth, stride_height, stride_width, pad_height,
284 pad_width, padding, output_data, output_height, output_width);