Ensure conj is resolved prior to async host->device copy · pytorch/pytorch@ce451dc · GitHub
[go: up one dir, main page]

Skip to content

Commit ce451dc

Browse files
committed
Ensure conj is resolved prior to async host->device copy
Fixes #146286

ghstack-source-id: feb4d9b
Pull Request resolved: #147231
1 parent e1fe699 commit ce451dc

File tree

2 files changed

+117
-24
lines changed

2 files changed

+117
-24
lines changed

aten/src/ATen/native/cuda/Copy.cu

+74-23
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,24 @@ void copy_device_to_device(TensorIterator& iter,
294294
AT_CUDA_CHECK(cudaGetLastError());
295295
}
296296

297-
static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
297+
inline std::tuple<size_t, size_t, size_t, size_t> getCopyParameters(const TensorIteratorBase& iter) {
298+
size_t element_size = iter.tensor(0).element_size();
299+
if (iter.ndim() == 1) {
300+
size_t width_in_bytes = element_size;
301+
size_t src_pitch = iter.strides(1)[0];
302+
size_t dst_pitch = iter.strides(0)[0];
303+
size_t height = iter.shape()[0];
304+
return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
305+
} else {
306+
size_t width_in_bytes = iter.shape()[0] * element_size;
307+
size_t src_pitch = iter.strides(1)[1];
308+
size_t dst_pitch = iter.strides(0)[1];
309+
size_t height = iter.shape()[1];
310+
return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
311+
}
312+
}
313+
314+
static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled, bool non_blocking) {
298315
Device dst_device = iter.device(0);
299316
Device src_device = iter.device(1);
300317

@@ -303,19 +320,37 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
303320
TORCH_INTERNAL_ASSERT(dst_device.is_cuda() && src_device.is_cuda());
304321
return false;
305322
}
323+
bool different_conj_neg = iter.tensor(0).is_conj() != iter.tensor(1).is_conj() || iter.tensor(0).is_neg() != iter.tensor(1).is_neg();
324+
if(!different_conj_neg || !non_blocking){
325+
//The fast-path checks in this branch apply only when src and dst agree on
326+
//their conj/neg flags or the copy is blocking. An async copy with mismatched
327+
//flags falls through and requires a temporary; otherwise, the conj_physical applied to dst may execute before the copy has completed.
328+
329+
bool same_dtype = iter.dtype(0) == iter.dtype(1);
330+
if (same_dtype && iter.is_contiguous()) {
331+
// Contiguous same-dtype copies can always use cudaMemcpyAsync if we don't have a conjugate flag to handle first
332+
return false;
333+
} else if (dst_device.is_cuda() && src_device.is_cuda()) {
334+
// Copies between GPUs can use the copy kernel if P2P is supported
335+
return !p2p_enabled;
336+
}
306337

307-
bool same_dtype = iter.dtype(0) == iter.dtype(1);
308-
if (same_dtype && iter.is_contiguous()) {
309-
// Contiguous same-dtype copies can always use cudaMemcpyAsync
310-
return false;
311-
} else if (dst_device.is_cuda() && src_device.is_cuda()) {
312-
// Copies between GPUs can use the copy kernel if P2P is supported
313-
return !p2p_enabled;
314-
} else {
315-
// The remaining cases require temporaries. For example, this includes
316-
// non-contiguous copies between CPU and GPU.
317-
return true;
338+
//for cross-device copies we can use memcpy2d if conditions are satisfied
339+
if (dst_device.is_cuda() != src_device.is_cuda() && same_dtype && iter.ndim() <= 2) {
340+
// TensorIterator reorders strides so that the first one is the smallest
341+
342+
if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
343+
auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
344+
if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
345+
return false; // No need for temporaries
346+
}
347+
}
348+
}
318349
}
350+
351+
// The remaining cases require temporaries. For example, this includes
352+
// non-contiguous copies between CPU and GPU.
353+
return true;
319354
}
320355

321356
static bool maybe_enable_p2p_access(Device dst_device, Device src_device) {
@@ -333,8 +368,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
333368

334369
// Enable p2p access between devices. (No-op if it involves the CPU)
335370
bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);
336-
337-
if (copy_requires_temporaries(iter, p2p_enabled)) {
371+
if (copy_requires_temporaries(iter, p2p_enabled, non_blocking)) {
338372
// NB: this involves recursive calls to copy. Be careful that those copies
339373
// don't require temporaries or you will cause an infinite recursion!
340374
auto& dst = iter.tensor(0);
@@ -355,16 +389,31 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
355389
src_contig = iter.tensor(1).expand_as(dst).contiguous();
356390
}
357391

358-
// propagate the correct conjugate bit
359-
dst_contig._set_conj(dst.is_conj());
360-
src_contig._set_conj(iter.tensor(1).is_conj());
392+
if (non_blocking){
393+
const bool need_conj = dst.is_conj() != src_contig.is_conj();
394+
const bool need_neg = dst.is_neg() != src_contig.is_neg();
395+
// Doing these inplace may lead to modification on the src if none of
396+
// the above materialized a clone.
397+
if(need_conj){
398+
src_contig = src_contig.conj_physical();
399+
}
400+
if(need_neg){
401+
src_contig = src_contig.neg();
402+
}
403+
src_contig._set_conj(dst.is_conj());
404+
src_contig._set_neg(dst.is_neg());
405+
}
406+
361407

408+
// propagate the correct conjugate bit to dst
409+
dst_contig._set_conj(dst.is_conj());
362410
dst_contig._set_neg(dst.is_neg());
363-
src_contig._set_neg(iter.tensor(1).is_neg());
364411

365412
// perform a same-dtype copy on contiguous tensors
366413
TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
367414
TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
415+
TORCH_INTERNAL_ASSERT(dst_contig.is_conj() == src_contig.is_conj());
416+
TORCH_INTERNAL_ASSERT(dst_contig.is_neg() == src_contig.is_neg());
368417
dst_contig.copy_(src_contig, non_blocking);
369418

370419
// if necessary, copy back into dst
@@ -424,11 +473,13 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
424473
at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
425474
}
426475

427-
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
428-
iter.tensor(0).conj_physical_();
429-
}
430-
if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
431-
iter.tensor(0).neg_();
476+
if(!non_blocking){
477+
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
478+
iter.tensor(0).conj_physical_();
479+
}
480+
if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
481+
iter.tensor(0).neg_();
482+
}
432483
}
433484
}
434485

test/test_complex.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,15 @@
66
dtypes,
77
instantiate_device_type_tests,
88
onlyCPU,
9+
skipIf,
910
)
1011
from torch.testing._internal.common_dtype import complex_types
11-
from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase
12+
from torch.testing._internal.common_utils import (
13+
parametrize,
14+
run_tests,
15+
set_default_dtype,
16+
TestCase,
17+
)
1218

1319

1420
devices = (torch.device("cpu"), torch.device("cuda:0"))
@@ -44,6 +50,42 @@ def test_conj_copy(self, device, dtype):
4450
x1.copy_(xc1)
4551
self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))
4652

53+
@skipIf(not torch.cuda.is_available(), "Test only applies to CUDA enabled builds")
54+
@dtypes(*complex_types())
55+
@parametrize(
56+
"src_conj,dst_conj",
57+
[(True, True), (False, True), (True, False), (False, False)],
58+
)
59+
def test_conj_copy_async(self, device, dtype, src_conj, dst_conj):
60+
# issue: https://github.com/pytorch/pytorch/issues/146286
61+
def pin(d):
62+
return d == "cpu"
63+
64+
src = torch.tensor(
65+
[5 + 1j, 2 + 2j], device=device, dtype=dtype, pin_memory=pin(device)
66+
)
67+
src_block = src.clone()
68+
if src_conj:
69+
src = src.conj()
70+
src_block = src_block.conj()
71+
72+
src_ref = src.clone()
73+
74+
# The copy is cross device so parameterize on the source device and make sure the dst is the other one
75+
dst_device = "cuda:0" if device == "cpu" else "cpu"
76+
dst = torch.zeros_like(src, device=dst_device, pin_memory=pin(dst_device))
77+
dst_block = torch.zeros_like(src, device=dst_device, pin_memory=pin(dst_device))
78+
if dst_conj:
79+
dst = dst.conj()
80+
dst_block = dst_block.conj()
81+
82+
dst.copy_(src, non_blocking=True)
83+
dst_block.copy_(src_block, non_blocking=False)
84+
85+
self.assertTrue(dst.is_conj() == dst_conj)
86+
self.assertEqual(dst_block, dst)
87+
self.assertEqual(src, src_ref)
88+
4789
@dtypes(*complex_types())
4890
def test_all(self, device, dtype):
4991
# issue: https://github.com/pytorch/pytorch/issues/120875

0 commit comments

Comments
 (0)
0