36 void UpdateSlice(int32_t current_dim, int32_t max_dim,
const std::vector<int32_t> &output_stride,
53 UpdateSlice(current_dim + 1, max_dim, output_stride, update_stride, update_shape, update,
63 const T *update_data,
const std::vector<int64_t> &indices_data, T *output_data)
66 if (input_shape == update_shape)
68 memcpy(output_data, update_data, update_shape.
FlatSize() *
sizeof(T));
73 if (input_data != output_data)
74 memcpy(output_data, input_data, input_shape.
FlatSize() *
sizeof(T));
82 std::vector<int64_t> clamped_start_indices(input_dims, 0);
84 for (
int i = 0; i < input_dims; i++)
86 clamped_start_indices[i] = std::min<int64_t>(std::max<int64_t>(0, indices_data[i]),
87 input_shape.
Dims(i) - update_shape.
Dims(i));
91 std::vector<int32_t> output_stride(input_dims);
92 std::vector<int32_t> update_stride(input_dims);
93 output_stride[input_dims - 1] = 1;
94 update_stride[input_dims - 1] = 1;
95 for (
int i = input_dims - 2; i >= 0; --i)
97 output_stride[i] = output_stride[i + 1] * input_shape.
Dims(i + 1);
98 update_stride[i] = update_stride[i + 1] * update_shape.
Dims(i + 1);
101 UpdateSlice<T>(0, input_dims, output_stride, update_stride, update_shape, update_data,
102 clamped_start_indices, output_data);