Ensure conj is resolved prior to async host->device copy · pytorch/pytorch@f7663c5 · GitHub

Commit f7663c5

Ensure conj is resolved prior to async host->device copy
Fixes #146286

ghstack-source-id: 0c44ccf
Pull Request resolved: #147231
1 parent: 0c8028e · commit: f7663c5
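Background: torch.Tensor.conj() is lazy — it returns a view that shares storage and only sets a conjugate flag. Before this change, a non_blocking cross-device copy could hit the raw cudaMemcpyAsync fast path while the source still carried an unresolved conjugate bit, copying the unconjugated buffer verbatim. A minimal sketch of the failure mode, adapted from the regression test added below (assumes a CUDA device is available):

import torch

# Source on the GPU with a lazy conjugate bit: memory still holds 5+1j, 2+2j.
src = torch.tensor([5 + 1j, 2 + 2j], device="cuda").conj()
# A pinned (page-locked) destination makes the copy genuinely asynchronous.
dst = torch.zeros_like(src, device="cpu").pin_memory()
dst.copy_(src, non_blocking=True)
torch.cuda.synchronize()  # wait for the async copy to land before reading dst
print(dst)  # expected tensor([5.-1.j, 2.-2.j]); before this fix the raw +1j/+2j values could come through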

File tree: 2 files changed, +28 -16 lines

aten/src/ATen/native/cuda/Copy.cu
test/test_complex.py

aten/src/ATen/native/cuda/Copy.cu (18 additions, 16 deletions)

@@ -289,7 +289,7 @@ inline std::tuple<size_t, size_t, size_t, size_t> getCopyParameters(const Tensor
   }
 }
 
-static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
+static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled, bool non_blocking) {
   Device dst_device = iter.device(0);
   Device src_device = iter.device(1);
 
@@ -300,8 +300,9 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
   }
 
   bool same_dtype = iter.dtype(0) == iter.dtype(1);
-  if (same_dtype && iter.is_contiguous()) {
-    // Contiguous same-dtype copies can always use cudaMemcpyAsync
+  bool same_conj_neg = iter.tensor(0).is_conj() == iter.tensor(1).is_conj() && iter.tensor(0).is_neg() == iter.tensor(1).is_neg();
+  if (same_dtype && iter.is_contiguous() && same_conj_neg) {
+    // Contiguous same-dtype copies can use cudaMemcpyAsync only if there is no conjugate or negative bit to resolve first
     return false;
   } else if (dst_device.is_cuda() && src_device.is_cuda()) {
     // Copies between GPUs can use the copy kernel if P2P is supported
@@ -310,12 +311,15 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
 
   // for cross-device copies we can use memcpy2d if conditions are satisfied
   if (dst_device.is_cuda() != src_device.is_cuda() && same_dtype && iter.ndim() <= 2) {
-    // TensorIterator reorders strides so that the first one is the smallest
-
-    if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
-      auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
-      if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
-        return false; // No need for temporaries
+    // We can do an async copy only if there is no conjugate or negative bit to resolve
+    if (same_conj_neg || !non_blocking) {
+      // TensorIterator reorders strides so that the first one is the smallest
+
+      if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
+        auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
+        if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
+          return false; // No need for temporaries
+        }
       }
     }
   }
@@ -340,8 +344,8 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
 
   // Enable p2p access between devices. (No-op if it involves the CPU)
   bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);
-
-  if (copy_requires_temporaries(iter, p2p_enabled)) {
+  bool temp_needed = copy_requires_temporaries(iter, p2p_enabled, non_blocking);
+  if (temp_needed) {
     // NB: this involves recursive calls to copy. Be careful that those copies
     // don't require temporaries or you will cause an infinite recursion!
     auto& dst = iter.tensor(0);
@@ -355,19 +359,17 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
     auto conversion_device = non_blocking ? kCUDA : kCPU;
     if (iter.device_type(1) == conversion_device) {
       dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous();
+      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous().resolve_conj();
     } else {
       bool same_type = iter.dtype(0) == iter.dtype(1);
       dst_contig = (dst.is_contiguous() && same_type) ? dst : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-      src_contig = iter.tensor(1).expand_as(dst).contiguous();
+      src_contig = iter.tensor(1).expand_as(dst).contiguous().resolve_conj();
     }
 
-    // propagate the correct conjugate bit
+    // propagate the correct conjugate bit to dst; src was resolved above
     dst_contig._set_conj(dst.is_conj());
-    src_contig._set_conj(iter.tensor(1).is_conj());
 
     dst_contig._set_neg(dst.is_neg());
-    src_contig._set_neg(iter.tensor(1).is_neg());
 
     // perform a same-dtype copy on contiguous tensors
     TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
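For reference, resolve_conj() is the standard torch API for materializing a lazy conjugation into real memory, which is exactly what a flag-unaware memcpy needs. A short illustration using only stock calls:

import torch

t = torch.tensor([1 + 2j, 3 - 4j])
v = t.conj()          # lazy view: shares storage, only the conj bit is set
print(v.is_conj())    # True
r = v.resolve_conj()  # materializes the conjugated values into new memory
print(r.is_conj())    # False
print(r)              # tensor([1.-2.j, 3.+4.j])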

test/test_complex.py (10 additions, 0 deletions)

@@ -6,6 +6,7 @@
     dtypes,
     instantiate_device_type_tests,
     onlyCPU,
+    onlyCUDA,
 )
 from torch.testing._internal.common_dtype import complex_types
 from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase
@@ -44,6 +45,15 @@ def test_conj_copy(self, device, dtype):
         x1.copy_(xc1)
         self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))
 
+    @onlyCUDA
+    @dtypes(*complex_types())
+    def test_conj_copy_async_h2d(self, device, dtype):
+        # issue: https://github.com/pytorch/pytorch/issues/146286
+        x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype).conj()
+        x2 = torch.zeros_like(x1, device="cpu").pin_memory()
+        x2.copy_(x1, non_blocking=True)
+        self.assertEqual(x1, x2)
+
     @dtypes(*complex_types())
     def test_all(self, device, dtype):
         # issue: https://github.com/pytorch/pytorch/issues/120875
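A note on why the test pins its destination: cudaMemcpyAsync only overlaps with the host when the host buffer is page-locked, so non_blocking=True needs pinned memory to exercise the genuinely asynchronous path where the unresolved conjugate bit was lost; with pageable memory the transfer degrades to an effectively synchronous copy. A quick way to confirm a tensor is pinned (stock torch API):

import torch

x = torch.zeros(2, dtype=torch.complex64).pin_memory()
print(x.is_pinned())  # True: eligible for true async H2D/D2H transfers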
