51 const int32_t *block_shape_data,
const int32_t *crops_data,
52 const Shape &unextended_output_shape, T *output_data)
57 assert(input_dim == 3 || input_dim == 4);
58 assert(input_dim == output_dim);
64 auto extend_shape = [](
const Shape &shape) {
65 if (shape.DimensionsCount() == 4)
69 Shape new_shape(4, 1);
70 new_shape.
SetDim(0, shape.Dims(0));
71 new_shape.
SetDim(1, shape.Dims(1));
72 new_shape.
SetDim(3, shape.Dims(2));
75 const Shape input1_shape = extend_shape(unextended_input1_shape);
82 const int32_t depth = input1_shape.
Dims(3);
83 const int32_t input_width = input1_shape.
Dims(2);
84 const int32_t input_height = input1_shape.
Dims(1);
85 const int32_t input_batch_size = input1_shape.
Dims(0);
87 const int32_t block_shape_height = block_shape_data[0];
88 const int32_t block_shape_width = block_shape_data[1];
90 const int32_t crops_top = crops_data[0];
91 const int32_t crops_left = crops_data[2];
93 for (
int in_batch = 0; in_batch < input_batch_size; ++in_batch)
95 const int out_batch = in_batch % output_batch_size;
96 const int spatial_offset = in_batch / output_batch_size;
101 GetIndexRange(spatial_offset / block_shape_width - crops_top, block_shape_height, input_height,
102 output_height, &in_h_start, &in_h_end);
104 for (
int in_h = in_h_start; in_h < in_h_end; ++in_h)
106 const int out_h = in_h * block_shape_height + spatial_offset / block_shape_width - crops_top;
108 assert(out_h < output_height);
113 GetIndexRange(spatial_offset % block_shape_width - crops_left, block_shape_width, input_width,
114 output_width, &in_w_start, &in_w_end);
116 for (
int in_w = in_w_start; in_w < in_w_end; ++in_w)
119 in_w * block_shape_width + spatial_offset % block_shape_width - crops_left;
121 assert(out_w < output_width);
123 const T *in = input1_data +
Offset(input1_shape, in_batch, in_h, in_w, 0);
124 memcpy(out, in, depth *
sizeof(T));
void BatchToSpaceND(const Shape &unextended_input1_shape, const T *input1_data, const int32_t *block_shape_data, const int32_t *crops_data, const Shape &unextended_output_shape, T *output_data)