34 int axis = params.
axis < 0 ? params.
axis + split_dimensions : params.
axis;
37 int64_t outer_size = 1;
38 for (
int i = 0; i < axis; ++i)
40 outer_size *= input_shape.
Dims(i);
44 int64_t base_inner_size = 1;
45 for (
int i = axis + 1; i < split_dimensions; ++i)
47 base_inner_size *= input_shape.
Dims(i);
50 const Scalar *input_ptr = input_data;
51 for (
int k = 0; k < outer_size; k++)
53 for (
int i = 0; i < outputs_count; ++i)
55 const int copy_size =
output_shape.Dims(axis) * base_inner_size;
56 memcpy(output_data[i] + k * copy_size, input_ptr, copy_size *
sizeof(Scalar));
57 input_ptr += copy_size;
void Split(const SplitParams ¶ms, const Shape &input_shape, const Scalar *input_data, const Shape &output_shape, Scalar *const *output_data)