37 const auto &data_shape = datav.
getShape();
38 const auto &indices_shape = indicesv.
getShape();
45 axis += data_shape.rank();
46 assert(axis >= 0 && axis < data_shape.rank());
47 int32_t axis_size = data_shape.dim(axis);
48 int32_t num_indices = indices_shape.numElements();
50 int32_t outer_size = 1;
51 for (int32_t i = 0; i < axis; ++i)
52 outer_size *= data_shape.dim(i);
54 int32_t inner_size = 1;
55 for (int32_t i = axis + 1; i < data_shape.rank(); ++i)
56 inner_size *= data_shape.dim(i);
58 for (int32_t outer = 0; outer < outer_size; ++outer)
60 for (int32_t i = 0; i < num_indices; ++i)
63 assert(index >= 0 && index < axis_size);
64 for (int32_t inner = 0; inner < inner_size; inner++)
66 output.atOffset((outer * num_indices + i) * inner_size + inner) =
67 data.atOffset((outer * axis_size + index) * inner_size + inner);