18#ifndef __NNFW_CKER_TRAIN_OPERATION_CONV_H__
19#define __NNFW_CKER_TRAIN_OPERATION_CONV_H__
44 Eigen::DenseIndex row_stride, Eigen::DenseIndex col_dilation,
45 Eigen::DenseIndex row_dilation)
47 input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
48 filter, output_backward, input_backward.dimension(2), input_backward.dimension(1), col_stride,
49 row_stride, col_dilation, row_dilation);
58 Eigen::DenseIndex padded_rows, Eigen::DenseIndex col_stride,
59 Eigen::DenseIndex row_stride, Eigen::DenseIndex col_dilation,
60 Eigen::DenseIndex row_dilation, Eigen::DenseIndex pad_left,
61 Eigen::DenseIndex pad_top)
68 input_backward.device(d) =
69 Eigen::SpatialConvolutionBackwardInput(filter, output_backward, padded_cols, padded_rows,
70 col_stride, row_stride, col_dilation, row_dilation)
72 .slice(Eigen::DSizes<Eigen::DenseIndex, 4>{0, pad_left, pad_top, 0},
73 input_backward.dimensions());
84 void operator()(
const Device &d,
const T *out_backprop_data,
int batches,
int out_backprop_height,
85 int out_backprop_width,
int output_depth,
const T *filter_data,
int filter_height,
86 int filter_width,
int row_dilation,
int col_dilation,
int row_stride ,
87 int col_stride ,
const PaddingType &padding_type,
int padding_top,
88 int padding_bottom,
int padding_left,
int padding_right, T *in_backprop_data,
89 int in_backprop_height,
int in_backprop_width,
int input_depth)
96 in_backprop_width, input_depth);
100 out_backprop_width, output_depth);
107 d, in_backprop_t, filter_t, out_backprop_t, col_stride, row_stride, col_dilation,
113 d, in_backprop_t, filter_t, out_backprop_t,
114 in_backprop_t.dimension(2) + (padding_left + padding_right),
115 in_backprop_t.dimension(1) + (padding_top + padding_bottom), col_stride, row_stride,
116 col_dilation, row_dilation, padding_top, padding_left);
124 void operator()(
const T *out_backprop_data,
int batches,
int out_backprop_height,
125 int out_backprop_width,
int output_depth,
const T *filter_data,
int filter_height,
126 int filter_width,
int row_dilation,
int col_dilation,
int row_stride ,
127 int col_stride ,
const PaddingType &padding_type,
int padding_top,
128 int padding_bottom,
int padding_left,
int padding_right, T *in_backprop_data,
129 int in_backprop_height,
int in_backprop_width,
int input_depth)
133 out_backprop_width, output_depth, filter_data, filter_height, filter_width,
134 row_dilation, col_dilation, row_stride, col_stride, padding_type, padding_top,
135 padding_bottom, padding_left, padding_right, in_backprop_data, in_backprop_height,
136 in_backprop_width, input_depth);
143 void operator()(
const T *out_backprop_data,
int batches,
int out_backprop_height,
144 int out_backprop_width,
int output_depth,
const T *input_data,
int input_height,
145 int input_width,
int input_depth,
int row_dilation,
int col_dilation,
146 int row_stride ,
int col_stride ,
const PaddingType &padding_type,
147 int padding_top,
int padding_bottom,
int padding_left,
int padding_right,
148 T *filter_backprop_data,
int filter_backprop_height,
int filter_backprop_width)
151 filter_backprop_width, input_depth, output_depth);
155 out_backprop_width, output_depth);
163 filter_backprop_t.device(d) = Eigen::SpatialConvolutionBackwardKernel(
164 input_t, out_backprop_t, filter_backprop_t.dimension(1), filter_backprop_t.dimension(0),
165 col_stride, row_stride, col_dilation, row_dilation);
171 Eigen::array<std::pair<int, int>, 4> paddings;
172 paddings[0] = {0, 0};
173 paddings[1] = {padding_top, padding_bottom};
174 paddings[2] = {padding_left, padding_right};
175 paddings[3] = {0, 0};
177 auto padded_t = input_t.pad(paddings, T(0));
181 filter_backprop_t.device(d) = Eigen::SpatialConvolutionBackwardKernel(
182 padded_t, out_backprop_t, filter_backprop_t.dimension(1), filter_backprop_t.dimension(0),
183 col_stride, row_stride, col_dilation, row_dilation);
// ConvInputGrad (partial listing: the start of the signature and several body lines are not
// visible here). From the visible lines: the four padding values are asserted non-negative,
// batch/depth sizes are taken via MatchingDim and spatial sizes via Shape::Dims, any dilation
// other than 1 is rejected with std::runtime_error, and incoming_data / filter_data are then
// forwarded together with grad_data to a callee whose name is not visible in this listing.
189 const float *incoming_data,
const Shape &filter_shape,
190 const float *filter_data,
const int padding_bottom,
191 const int padding_right,
const Shape &grad_shape,
float *grad_data)
198 assert(padding_top >= 0);
199 assert(padding_bottom >= 0);
200 assert(padding_left >= 0);
201 assert(padding_right >= 0);
205 const int batches =
MatchingDim(grad_shape, 0, incoming_shape, 0);
206 const int input_depth =
MatchingDim(filter_shape, 2, grad_shape, 3);
207 const int output_depth =
MatchingDim(filter_shape, 3, incoming_shape, 3);
208 const int grad_height = grad_shape.
Dims(1);
209 const int grad_width = grad_shape.
Dims(2);
210 const int filter_height = filter_shape.
Dims(0);
211 const int filter_width = filter_shape.
Dims(1);
212 const int incoming_height = incoming_shape.
Dims(1);
213 const int incoming_width = incoming_shape.
Dims(2);
// Only a dilation of 1 is accepted. The message names this function (it previously named
// ConvFilterGrad) so that a failure is attributed to the right entry point.
215 if (dilation_rows != 1 || dilation_cols != 1)
216 throw std::runtime_error(
"cker::ConvInputGrad: not yet support dilation rates larger than 1.");
219 incoming_data, batches, incoming_height, incoming_width, output_depth, filter_data,
220 filter_height, filter_width, dilation_rows, dilation_cols, stride_rows, stride_cols, padding,
221 padding_top, padding_bottom, padding_left, padding_right, grad_data, grad_height, grad_width,
226 const float *incoming_data,
const Shape &input_shape,
227 const float *input_data,
const int padding_bottom,
228 const int padding_right,
const Shape &filter_backprop_shape,
229 float *filter_backprop_data)
236 assert(padding_top >= 0);
237 assert(padding_bottom >= 0);
238 assert(padding_left >= 0);
239 assert(padding_right >= 0);
243 const int batches =
MatchingDim(input_shape, 0, incoming_shape, 0);
244 const int input_depth =
MatchingDim(filter_backprop_shape, 2, input_shape, 3);
245 const int output_depth =
MatchingDim(filter_backprop_shape, 3, incoming_shape, 3);
246 const int input_height = input_shape.
Dims(1);
247 const int input_width = input_shape.
Dims(2);
248 const int filter_backprop_height = filter_backprop_shape.
Dims(0);
249 const int filter_backprop_width = filter_backprop_shape.
Dims(1);
250 const int incoming_height = incoming_shape.
Dims(1);
251 const int incoming_width = incoming_shape.
Dims(2);
253 if (dilation_rows != 1 || dilation_cols != 1)
254 throw std::runtime_error(
"cker::ConvFilterGrad: not yet support dilation rates larger than 1.");
257 incoming_data, batches, incoming_height, incoming_width, output_depth, input_data, input_height,
258 input_width, input_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, padding,
259 padding_top, padding_bottom, padding_left, padding_right, filter_backprop_data,
260 filter_backprop_height, filter_backprop_width);
int32_t Dims(int i) const
Eigen::TensorMap< Eigen::Tensor< float, 4, Eigen::RowMajor, Eigen::DenseIndex >, Eigen::Aligned > EigenTensor
const Eigen::ThreadPoolDevice * GetThreadPoolDevice()
Eigen::TensorMap< Eigen::Tensor< const float, 4, Eigen::RowMajor, Eigen::DenseIndex >, Eigen::Aligned > ConstEigenTensor
void ConvFilterGrad(const ConvParams &params, const Shape &incoming_shape, const float *incoming_data, const Shape &input_shape, const float *input_data, const int padding_bottom, const int padding_right, const Shape &filter_backprop_shape, float *filter_backprop_data)
void ConvInputGrad(const ConvParams &params, const Shape &incoming_shape, const float *incoming_data, const Shape &filter_shape, const float *filter_data, const int padding_bottom, const int padding_right, const Shape &grad_shape, float *grad_data)
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
PaddingValues padding_values
int16_t dilation_width_factor
int16_t dilation_height_factor
Eigen::TensorMap< Eigen::Tensor< const T, NDIMS, Eigen::RowMajor, IndexType >, Eigen::Aligned > ConstTensor
Eigen::TensorMap< Eigen::Tensor< T, NDIMS, Eigen::RowMajor, IndexType >, Eigen::Aligned > Tensor
void operator()(const T *out_backprop_data, int batches, int out_backprop_height, int out_backprop_width, int output_depth, const T *input_data, int input_height, int input_width, int input_depth, int row_dilation, int col_dilation, int row_stride, int col_stride, const PaddingType &padding_type, int padding_top, int padding_bottom, int padding_left, int padding_right, T *filter_backprop_data, int filter_backprop_height, int filter_backprop_width)