// Default constructor: builds `_im2col_shape` with the argument 4 and clears
// both the `_need_im2col` and `_prepared` flags.
// NOTE(review): the 4 is presumably the rank of the im2col shape (dimension 3 is
// written later by IsRequiredIm2col) — confirm against Shape's constructor.
39 Conv() : _im2col_shape(4), _need_im2col(false), _prepared(false) {}
// NOTE(review): the head of this signature and the braces of the body are not
// visible in this excerpt; only the trailing stride/dilation parameters and the
// forwarded call are shown.
42 uint32_t stride_width, uint32_t stride_height, uint32_t dilation_width_factor,
43 uint32_t dilation_height_factor)
// Forwards the shapes, strides and dilation factors unchanged to
// IsRequiredIm2col(), which records the decision in `_need_im2col` and writes
// into `_im2col_shape`.
47 IsRequiredIm2col(input_shape, kernel_shape,
output_shape, stride_width, stride_height,
48 dilation_width_factor, dilation_height_factor);
// NOTE(review): the head of this signature, the braces, and the conditions that
// choose between the two ConvFloat() calls below are not visible in this
// excerpt.
54 const Shape &filter_shape,
const float *filter_data,
const Shape &bias_shape,
56 ::ruy::Context *ruy_context)
// Scratch size: the flat size of `_im2col_shape` when im2col is needed,
// otherwise 0 (an empty vector).
67 int im2col_size = _need_im2col ? _im2col_shape.
FlatSize() : 0;
// The scratch buffer is a local vector, so it is allocated each time this code
// runs and released when the vector goes out of scope.
68 std::vector<float> im2col_data(im2col_size);
// Variant 1: hand ConvFloat() the locally allocated scratch buffer.
71 ConvFloat(params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data,
72 output_shape, output_data, _im2col_shape, im2col_data.data(), ruy_context);
// Variant 2: same call, but with no scratch buffer (nullptr).
// NOTE(review): presumably taken when the shape was not prepared up front —
// confirm against the elided condition.
76 ConvFloat(params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data,
77 output_shape, output_data, _im2col_shape,
nullptr, ruy_context);
// Float convolution expressed as a single matrix multiplication run by ruy.
//
// Picks the GEMM input (either the raw input tensor or the im2col scratch
// buffer), fills in the GEMM/ruy parameter objects and calls ::ruy::Mul().
//
// params       convolution parameters; forwarded to Im2col().
// input_*      input tensor shape and data.
// filter_*     filter tensor shape and data; Dims(1) is read as the filter
//              height and Dims(2) as the filter width.
// bias_shape   marked [[maybe_unused]].
// bias_data    installed as the GEMM bias.
// im2col_data  scratch buffer; must be null when neither im2col path is taken
//              (enforced by the assert below).
// ruy_context  context handed to ::ruy::Mul().
//
// NOTE(review): several original lines (part of the signature, the braces, the
// body of the dilated branch and the matrix set-up code) are not visible in
// this excerpt.
82 void ConvFloat(
const ConvParams &params,
const Shape &input_shape,
const float *input_data,
83 const Shape &filter_shape,
const float *filter_data,
84 [[maybe_unused]]
const Shape &bias_shape,
const float *bias_data,
86 float *im2col_data, ::ruy::Context *ruy_context)
// Zero byte passed to Im2col().
// NOTE(review): presumably the fill value for padded regions — confirm against
// Im2col's definition.
99 const uint8_t float_zero_byte = 0x00;
// Both are assigned in exactly one of the three branches below.
100 const float *gemm_input_data =
nullptr;
101 const Shape *gemm_input_shape =
nullptr;
102 const int filter_width = filter_shape.
Dims(2);
103 const int filter_height = filter_shape.
Dims(1);
// Dilated im2col is required as soon as either dilation factor differs from 1.
104 const bool need_dilated_im2col = dilation_width_factor != 1 || dilation_height_factor != 1;
// Plain im2col is required unless this is a 1x1 filter with unit strides.
105 const bool need_im2col =
106 stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1;
// Dilated case: the GEMM reads from the im2col buffer (the code that fills the
// buffer is not visible in this excerpt).
107 if (need_dilated_im2col)
111 gemm_input_data = im2col_data;
112 gemm_input_shape = &im2col_shape;
// Non-dilated case: Im2col() writes into the scratch buffer, then the GEMM
// reads from it.
114 else if (need_im2col)
117 Im2col(params, filter_height, filter_width, float_zero_byte, input_shape, input_data,
118 im2col_shape, im2col_data);
119 gemm_input_data = im2col_data;
120 gemm_input_shape = &im2col_shape;
// Remaining case (1x1 filter, unit strides, no dilation): the input tensor is
// used directly, so no scratch buffer may have been supplied.
126 assert(!im2col_data);
127 gemm_input_data = input_data;
128 gemm_input_shape = &input_shape;
// k is the size of the last dimension of the chosen GEMM input shape.
134 int k = gemm_input_shape->
Dims(gemm_input_dims - 1);
// Parameter objects for the left-hand side, right-hand side and destination
// matrices (their field assignments are not visible in this excerpt).
138 MatrixParams<float> lhs_params;
142 MatrixParams<float> rhs_params;
146 MatrixParams<float> dst_params;
// Bias and the activation clamp range are carried in the GEMM parameters.
150 GemmParams<float, float> gemm_params;
151 gemm_params.bias = bias_data;
152 gemm_params.clamp_min = output_activation_min;
153 gemm_params.clamp_max = output_activation_max;
// ruy-side matrix objects and multiplication parameters.
156 ::ruy::Matrix<float> ruy_lhs;
157 ::ruy::Matrix<float> ruy_rhs;
158 ::ruy::Matrix<float> ruy_dst;
164 ::ruy::MulParams<float, float> ruy_mul_params;
// Run the multiplication; the result is written through ruy_dst.
167 ::ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, ruy_context, &ruy_dst);
// Decides whether the convolution needs an im2col transformation and stores
// the answer in `_need_im2col`; also writes dimension 3 of `_im2col_shape`.
// Returns nothing: the results live in member state.
//
// NOTE(review): the middle of the signature (original line 171) and part of the
// body (original lines 179-184) are not visible in this excerpt, so any guard
// around the final SetDim() call cannot be confirmed from here.
170 void IsRequiredIm2col(
const Shape &input_shape,
const Shape &kernel_shape,
172 uint32_t dilation_width_factor, uint32_t dilation_height_factor)
// Any dilation factor other than 1 requires the dilated im2col path.
174 const bool need_dilated_im2col = dilation_width_factor != 1 || dilation_height_factor != 1;
// Without dilation, im2col is still required unless strides are 1 and kernel
// dimensions 1 and 2 are both 1.
175 const bool need_non_dilated_im2col = stride_width != 1 || stride_height != 1 ||
176 kernel_shape.Dims(1) != 1 || kernel_shape.Dims(2) != 1;
178 _need_im2col = need_dilated_im2col || need_non_dilated_im2col;
// Dimension 3 of the im2col shape is the product of input dimension 3 and
// kernel dimensions 1 and 2.
185 _im2col_shape.
SetDim(3, input_shape.Dims(3) * kernel_shape.Dims(1) * kernel_shape.Dims(2));