40 const int indices_nd = indices_shape.
dims(indices_dims - 1);
44 for (
int i = 0; i < indices_dims - 1; ++i)
46 n_slices *= indices_shape.
dims(i);
52 for (
int i = indices_nd; i < params_dims; ++i)
54 slice_size *= params_shape.
dims(i);
57 int params_flat_size = params_shape.
flatSize();
58 int remain_flat_size = params_flat_size;
62 for (
int i = 0; i < indices_nd; ++i)
64 dims_to_count[i] = remain_flat_size / params_shape.
dims(i);
65 remain_flat_size = dims_to_count[i];
68 for (
int i = 0; i < n_slices; ++i)
71 for (
int j = 0; j < indices_nd; ++j)
73 int offset = i * indices_nd + j;
74 IndicesT index = index_data[
offset];
75 from_pos += index * dims_to_count[j];
77 if (from_pos < 0 || from_pos + slice_size > params_flat_size)
79 assert(
false &&
"GatherND error");
82 std::memcpy(output_data + i * slice_size, param_data + from_pos,
sizeof(ParamsT) * slice_size);