// GatherND: copies slices of `param_data` into `output_data`, one slice per index tuple read
// from `index_data`. This listing shows the body statements first; the function signature is
// the final line. Declarations of indices_dims, params_dims, n_slices, slice_size,
// dims_to_count and from_pos are not part of this listing.

// Extent of the last indices dimension; used below as the number of components in each
// index tuple (inner loop bound and row stride into index_data).
35 const int indices_nd = indices_shape.
dims(indices_dims - 1);
// Fold every indices dimension except the last into n_slices.
// NOTE(review): the initial value of n_slices is not visible here — presumably 1; confirm.
39 for (
int i = 0; i < indices_dims - 1; ++i)
41 n_slices *= indices_shape.
dims(i);
// Fold the params dimensions from position indices_nd up to params_dims - 1 into slice_size,
// i.e. the element count copied per tuple. When indices_nd == params_dims the loop does not run.
// NOTE(review): the initial value of slice_size is not visible here — presumably 1; confirm.
47 for (
int i = indices_nd; i < params_dims; ++i)
49 slice_size *= params_shape.
dims(i);
// Total params size as reported by flatSize(); also the upper bound for the range check below.
52 int params_flat_size = params_shape.
flatSize();
53 int remain_flat_size = params_flat_size;
// Per-dimension multipliers: each step divides the remaining size by params dim i, so
// dims_to_count[i] is the factor applied to index component i when forming a flat offset.
// NOTE(review): the capacity of dims_to_count is not visible; nothing shown here bounds
// indices_nd against it — confirm at the declaration or in the caller.
// NOTE(review): a params dimension of 0 would make this a division by zero — confirm callers
// never pass one.
57 for (
int i = 0; i < indices_nd; ++i)
59 dims_to_count[i] = remain_flat_size / params_shape.
dims(i);
60 remain_flat_size = dims_to_count[i];
// One iteration per output slice.
// NOTE(review): from_pos is accumulated below, but its declaration/reset is not visible —
// presumably zeroed at the top of each outer iteration; confirm.
63 for (
int i = 0; i < n_slices; ++i)
// Accumulate the flat source offset from the indices_nd components of tuple i.
66 for (
int j = 0; j < indices_nd; ++j)
68 int offset = i * indices_nd + j;
69 IndicesT index = index_data[
offset];
70 from_pos += index * dims_to_count[j];
// Range check on the combined offset only: the slice [from_pos, from_pos + slice_size) must
// lie inside the params buffer. Individual index components are not checked against their
// own dimension in the visible code.
// NOTE(review): assert() expands to nothing when NDEBUG is defined; the listing does not show
// whether another statement prevents the memcpy below from running in that case — confirm.
72 if (from_pos < 0 || from_pos + slice_size > params_flat_size)
74 assert(
false &&
"GatherND error");
// Copy slice_size elements of ParamsT from the computed source offset to output slot i.
77 std::memcpy(output_data + i * slice_size, param_data + from_pos,
sizeof(ParamsT) * slice_size);
// Signature of the function whose statements are listed above.
//   params_shape / param_data   : shape and element buffer that slices are read from
//   indices_shape / index_data  : shape and buffer of index tuples (indices_nd components each)
//   output_data                 : destination buffer; written at i * slice_size for each slice i
void GatherND(luci_interpreter::RuntimeShape params_shape, const ParamsT *param_data, luci_interpreter::RuntimeShape indices_shape, const IndicesT *index_data, ParamsT *output_data)