Ensure conj/neg flags are set in destination for CUDA->CPU copies by amjames · Pull Request #147231 · pytorch/pytorch

Open · wants to merge 6 commits into gh/amjames/20/base

97 changes: 74 additions & 23 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -294,7 +294,24 @@ void copy_device_to_device(TensorIterator& iter,
AT_CUDA_CHECK(cudaGetLastError());
}

static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
inline std::tuple<size_t, size_t, size_t, size_t> getCopyParameters(const TensorIteratorBase& iter) {
size_t element_size = iter.tensor(0).element_size();
if (iter.ndim() == 1) {
size_t width_in_bytes = element_size;
size_t src_pitch = iter.strides(1)[0];
size_t dst_pitch = iter.strides(0)[0];
size_t height = iter.shape()[0];
return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
} else {
size_t width_in_bytes = iter.shape()[0] * element_size;
size_t src_pitch = iter.strides(1)[1];
size_t dst_pitch = iter.strides(0)[1];
size_t height = iter.shape()[1];
return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
}
}
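
For orientation, here is a rough Python-level illustration of the values getCopyParameters would derive for the 2-D memcpy path. The tensor shapes, the float32 dtype, and the element_size() scaling (torch's .stride() counts elements, while the iterator reports byte strides) are assumptions for illustration only, not part of the patch, and the snippet assumes a CUDA build.

```python
import torch

# Hypothetical worked example (CUDA build assumed): a pinned-CPU -> CUDA copy of a
# row-sliced 2-D tensor that satisfies the pitched-copy conditions checked below.
src = torch.randn(4, 10, pin_memory=True)[:, :6]  # shape (4, 6), outer stride 10

elem = src.element_size()              # 4 bytes for float32
width_in_bytes = src.shape[1] * elem   # 6 * 4 = 24 payload bytes per row
src_pitch = src.stride(0) * elem       # 10 * 4 = 40 bytes between row starts
height = src.shape[0]                  # 4 rows

# Same dtype, ndim <= 2, contiguous innermost dim, and both pitches >= width:
# no contiguous temporaries should be needed for this copy.
dst = torch.empty(4, 6, device="cuda")  # contiguous, so dst pitch == width
dst.copy_(src, non_blocking=True)
```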

static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled, bool non_blocking) {
Device dst_device = iter.device(0);
Device src_device = iter.device(1);

@@ -303,19 +320,37 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
TORCH_INTERNAL_ASSERT(dst_device.is_cuda() && src_device.is_cuda());
return false;
}
bool different_conj_neg = iter.tensor(0).is_conj() != iter.tensor(1).is_conj() || iter.tensor(0).is_neg() != iter.tensor(1).is_neg();
if (!different_conj_neg || !non_blocking) {
// We require a temporary when the conj/neg flags of src and dst differ and
// the copy is async. Otherwise, the conj_physical applied to dst may execute
// before the copy has completed.

bool same_dtype = iter.dtype(0) == iter.dtype(1);
if (same_dtype && iter.is_contiguous()) {
// Contiguous same-dtype copies can always use cudaMemcpyAsync if we don't have a conjugate flag to handle first
return false;
} else if (dst_device.is_cuda() && src_device.is_cuda()) {
// Copies between GPUs can use the copy kernel if P2P is supported
return !p2p_enabled;
}

bool same_dtype = iter.dtype(0) == iter.dtype(1);
if (same_dtype && iter.is_contiguous()) {
// Contiguous same-dtype copies can always use cudaMemcpyAsync
return false;
} else if (dst_device.is_cuda() && src_device.is_cuda()) {
// Copies between GPUs can use the copy kernel if P2P is supported
return !p2p_enabled;
} else {
// The remaining cases require temporaries. For example, this includes
// non-contiguous copies between CPU and GPU.
return true;
//for cross-device copies we can use memcpy2d if conditions are satisfied
if (dst_device.is_cuda() != src_device.is_cuda() && same_dtype && iter.ndim() <= 2) {
// TensorIterator reorders strides so that the first one is the smallest

if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
return false; // No need for temporaries
}
}
}
}

// The remaining cases require temporaries. For example, this includes
// non-contiguous copies between CPU and GPU.
return true;
}
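
For context, the async case that this check now routes through temporaries (tracked in https://github.com/pytorch/pytorch/issues/146286) can be sketched at the Python level as follows. The values mirror the new test added below, and the snippet assumes a CUDA-enabled build.

```python
import torch

# Sketch of the async CUDA->CPU copy whose source carries only the lazy conj bit.
# Materializing the conjugate must not race with the asynchronous memcpy.
src = torch.tensor([5 + 1j, 2 + 2j], device="cuda").conj()    # conj flag set, data unchanged
dst = torch.zeros(2, dtype=torch.complex64, pin_memory=True)  # plain CPU destination

dst.copy_(src, non_blocking=True)  # cross-device, conj flags differ, async
torch.cuda.synchronize()

# The destination must hold the materialized conjugate values.
torch.testing.assert_close(dst, torch.tensor([5 - 1j, 2 - 2j]))
```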

static bool maybe_enable_p2p_access(Device dst_device, Device src_device) {
@@ -333,8 +368,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {

// Enable p2p access between devices. (No-op if it involves the CPU)
bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);

if (copy_requires_temporaries(iter, p2p_enabled)) {
if (copy_requires_temporaries(iter, p2p_enabled, non_blocking)) {
// NB: this involves recursive calls to copy. Be careful that those copies
// don't require temporaries or you will cause an infinite recursion!
auto& dst = iter.tensor(0);
@@ -355,16 +389,31 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
src_contig = iter.tensor(1).expand_as(dst).contiguous();
}

// propagate the correct conjugate bit
dst_contig._set_conj(dst.is_conj());
src_contig._set_conj(iter.tensor(1).is_conj());
if (non_blocking){
const bool need_conj = dst.is_conj() != src_contig.is_conj();
const bool need_neg = dst.is_neg() != src_contig.is_neg();
// Doing these in place could modify the original src if none of the
// branches above materialized a clone.
if(need_conj){
src_contig = src_contig.conj_physical();
}
if(need_neg){
src_contig = src_contig.neg();
}
src_contig._set_conj(dst.is_conj());
src_contig._set_neg(dst.is_neg());
}


// propagate the correct conjugate/negative bits to dst
dst_contig._set_conj(dst.is_conj());
dst_contig._set_neg(dst.is_neg());
src_contig._set_neg(iter.tensor(1).is_neg());

// perform a same-dtype copy on contiguous tensors
TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
TORCH_INTERNAL_ASSERT(dst_contig.is_conj() == src_contig.is_conj());
TORCH_INTERNAL_ASSERT(dst_contig.is_neg() == src_contig.is_neg());
dst_contig.copy_(src_contig, non_blocking);

// if necessary, copy back into dst
@@ -424,11 +473,13 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
}

if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
iter.tensor(0).conj_physical_();
}
if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
iter.tensor(0).neg_();
if(!non_blocking){
if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
iter.tensor(0).conj_physical_();
}
if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
iter.tensor(0).neg_();
}
}
}
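
As a sanity check on the blocking branch above: dst keeps its own conj/neg flags and only its physical data is fixed up, which is safe because memcpy_and_sync has already completed. A small sketch of the observable behavior (CUDA build assumed; the values and the resolve_conj-based check are illustrative only):

```python
import torch

# Blocking CUDA->CPU copy into a conj-flagged destination: the in-place fix-up
# runs only after the synchronizing memcpy, so the result is well defined.
src = torch.tensor([1 + 1j, 2 - 3j], device="cuda")  # plain source
dst = torch.zeros(2, dtype=torch.complex64).conj()   # destination carries the conj bit

dst.copy_(src, non_blocking=False)

assert dst.is_conj()  # the destination's flag is preserved
torch.testing.assert_close(dst.resolve_conj(), torch.tensor([1 + 1j, 2 - 3j]))
```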

44 changes: 43 additions & 1 deletion test/test_complex.py
@@ -6,9 +6,15 @@
dtypes,
instantiate_device_type_tests,
onlyCPU,
skipIf,
)
from torch.testing._internal.common_dtype import complex_types
from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
set_default_dtype,
TestCase,
)


devices = (torch.device("cpu"), torch.device("cuda:0"))
@@ -44,6 +50,42 @@ def test_conj_copy(self, device, dtype):
x1.copy_(xc1)
self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))

@skipIf(not torch.cuda.is_available(), "Test only applies to CUDA enabled builds")
@dtypes(*complex_types())
@parametrize(
"src_conj,dst_conj",
[(True, True), (False, True), (True, False), (False, False)],
)
def test_conj_copy_async(self, device, dtype, src_conj, dst_conj):
# issue: https://github.com/pytorch/pytorch/issues/146286
def pin(d):
return d == "cpu"

src = torch.tensor(
[5 + 1j, 2 + 2j], device=device, dtype=dtype, pin_memory=pin(device)
)
src_block = src.clone()
if src_conj:
src = src.conj()
src_block = src_block.conj()

src_ref = src.clone()

# The copy is cross-device, so parametrize on the source device and make dst the other one
dst_device = "cuda:0" if device == "cpu" else "cpu"
dst = torch.zeros_like(src, device=dst_device, pin_memory=pin(dst_device))
dst_block = torch.zeros_like(src, device=dst_device, pin_memory=pin(dst_device))
if dst_conj:
dst = dst.conj()
dst_block = dst_block.conj()

dst.copy_(src, non_blocking=True)
dst_block.copy_(src_block, non_blocking=False)

self.assertTrue(dst.is_conj() == dst_conj)
self.assertEqual(dst_block, dst)
self.assertEqual(src, src_ref)

@dtypes(*complex_types())
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875