64 {
65 // Produces output = input with the update tensor written into it at the per-dimension start offsets in indices_data, each clamped to [0, input dim - update dim]. The element-wise strided copy is delegated to UpdateSlice (defined elsewhere).
66 if (input_shape == update_shape) // Update covers the whole tensor: the clamp below would force every start index to 0, so output is simply a copy of update.
67 {
68 memcpy(output_data, update_data, update_shape.FlatSize() * sizeof(T));
69 return;
70 }
71
72 // Otherwise seed output with a copy of input. Skipped when both pointers are the same buffer (in-place update): memcpy with overlapping source and destination is undefined behavior.
73 if (input_data != output_data)
74 memcpy(output_data, input_data, input_shape.FlatSize() * sizeof(T));
75
76 // An empty update leaves the copied input unchanged, so there is nothing left to write.
77 if (update_shape.FlatSize() == 0)
78 return;
79
80 // Clamp each start index into [0, input_shape.Dims(i) - update_shape.Dims(i)] so the update window stays inside the input. NOTE(review): assumes every update dim <= the matching input dim; the assert below checks only that the ranks match.
81 const auto input_dims = input_shape.DimensionsCount();
82 std::vector<int64_t> clamped_start_indices(input_dims, 0);
83 assert(input_dims == update_shape.DimensionsCount()); // ranks must match; not checked when NDEBUG is defined
84 for (int i = 0; i < input_dims; i++)
85 {
86 clamped_start_indices[i] = std::min<int64_t>(std::max<int64_t>(0, indices_data[i]),
87 input_shape.Dims(i) - update_shape.Dims(i));
88 }
89
90 // Row-major strides: the last dimension is contiguous (stride 1) and each earlier stride is the product of all later dimension sizes. NOTE(review): indexes [input_dims - 1], so input_dims must be >= 1 here (rank-0 inputs presumably exit via the shape-equality fast path - confirm); int32_t strides could overflow for very large tensors.
91 std::vector<int32_t> output_stride(input_dims);
92 std::vector<int32_t> update_stride(input_dims);
93 output_stride[input_dims - 1] = 1;
94 update_stride[input_dims - 1] = 1;
95 for (int i = input_dims - 2; i >= 0; --i)
96 {
97 output_stride[i] = output_stride[i + 1] * input_shape.Dims(i + 1);
98 update_stride[i] = update_stride[i + 1] * update_shape.Dims(i + 1);
99 }
100 // Start the copy at dimension 0. UpdateSlice is defined elsewhere; presumably it walks each dimension and writes update_data into output_data at the clamped offsets - TODO confirm.
101 UpdateSlice<T>(0, input_dims, output_stride, update_stride, update_shape, update_data,
102 clamped_start_indices, output_data);
103 }