test_cuda.py passes on my machine · pytorch/pytorch@c2d84ea · GitHub

Commit c2d84ea

test_cuda.py passes on my machine
1 parent c90a3a0 commit c2d84ea

File tree

1 file changed

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 30 additions & 8 deletions
@@ -5,11 +5,13 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAEvent.h>
 #include <ATen/cuda/PeerToPeerAccess.h>
-#include <c10/cuda/CUDAStream.h>
 #include <ATen/native/Copy.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
 
+#include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/cuda/CUDAStream.h>
+
 namespace at {
 namespace native {
 
@@ -41,7 +43,9 @@ void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
 using namespace at::cuda;
 
 // device-to-device copy, does type conversion
-void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
+void copy_device_to_device(TensorIterator& iter,
+                           bool non_blocking,
+                           bool p2p_enabled) {
   int64_t numel = iter.numel();
 
   // We can memcpy the memory if both tensors have the same type AND both
@@ -82,11 +86,29 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
     void *src = iter.data_ptr(1);
     size_t size = numel * iter.element_size(0);
     if (src != dst || src_device != dst_device) {
-      // Perform the copy
-      AT_CUDA_CHECK(cudaMemcpyAsync(
-          dst, src, size,
-          cudaMemcpyDeviceToDevice,
-          copy_stream));
+#if CUDA_VERSION > 11040
+      // Due to bizarre cuda driver intricacies, copies of
+      // cudaMallocAsynced memory between devices that aren't
+      // peer-to-peer-capable need "cudaMemcpyPeerAsync".
+      static bool using_cudaMallocAsync = std::strcmp(CUDACachingAllocator::allocatorBackend(),
+                                                      "cudaMallocAsync") == 0;
+      bool needs_MemcpyPeer = (src_device != dst_device &&
+                               using_cudaMallocAsync &&
+                               !p2p_enabled);
+      if (needs_MemcpyPeer) {
+        AT_CUDA_CHECK(cudaMemcpyPeerAsync(
+            dst, dst_device.index(),
+            src, src_device.index(),
+            size, copy_stream));
+      } else {
+#endif
+        AT_CUDA_CHECK(cudaMemcpyAsync(
+            dst, src, size,
+            cudaMemcpyDeviceToDevice,
+            copy_stream));
+#if CUDA_VERSION > 11040
+      }
+#endif
     }
   } else {
     if (same_neg) {
@@ -199,7 +221,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
 
   // Copy on GPU (or between GPUs)
   if (dst_device.is_cuda() && src_device.is_cuda()) {
-    copy_device_to_device(iter, non_blocking);
+    copy_device_to_device(iter, non_blocking, p2p_enabled);
     return;
   }
 
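To make the driver behavior above concrete, here is a minimal standalone sketch (my own, not part of the commit) of the same dispatch: it allocates from each device's stream-ordered (cudaMallocAsync) pool, queries peer capability with cudaDeviceCanAccessPeer, and routes through cudaMemcpyPeerAsync when the devices are not peer-capable, mirroring the needs_MemcpyPeer split. It assumes two GPUs and CUDA 11.2 or newer; the peer-capable branch further assumes the pools have been granted cross-device access (e.g. via cudaMemPoolSetAccess), which is roughly what the p2p_enabled flag conveys in the committed code.

// p2p_copy_sketch.cu — illustrative sketch only; names and structure are my own.
// Build: nvcc -o p2p_copy_sketch p2p_copy_sketch.cu   (CUDA 11.2+, two GPUs)
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK(call)                                               \
  do {                                                            \
    cudaError_t err__ = (call);                                   \
    if (err__ != cudaSuccess) {                                   \
      fprintf(stderr, "CUDA error %s at %s:%d\n",                 \
              cudaGetErrorString(err__), __FILE__, __LINE__);     \
      exit(1);                                                    \
    }                                                             \
  } while (0)

int main() {
  int n = 0;
  CHECK(cudaGetDeviceCount(&n));
  if (n < 2) { printf("needs two GPUs\n"); return 0; }

  const size_t size = 1 << 20;  // 1 MiB payload
  cudaStream_t s0, s1;
  void *src = nullptr, *dst = nullptr;

  // Allocate from each device's stream-ordered (cudaMallocAsync) pool.
  CHECK(cudaSetDevice(0));
  CHECK(cudaStreamCreate(&s0));
  CHECK(cudaMallocAsync(&src, size, s0));
  CHECK(cudaSetDevice(1));
  CHECK(cudaStreamCreate(&s1));
  CHECK(cudaMallocAsync(&dst, size, s1));
  CHECK(cudaStreamSynchronize(s1));  // dst must exist before s0 touches it

  // Can each device map the other's memory directly?
  int p2p_01 = 0, p2p_10 = 0;
  CHECK(cudaDeviceCanAccessPeer(&p2p_01, 0, 1));
  CHECK(cudaDeviceCanAccessPeer(&p2p_10, 1, 0));
  bool p2p_capable = p2p_01 && p2p_10;

  CHECK(cudaSetDevice(0));
  if (!p2p_capable) {
    // Non-peer devices + pool memory: a plain D2D copy can fail here,
    // so route through the driver's peer copy, as the commit does.
    CHECK(cudaMemcpyPeerAsync(dst, 1, src, 0, size, s0));
  } else {
    // Peer-capable path. Assumes the pools were also granted
    // cross-device access (e.g. cudaMemPoolSetAccess) — roughly what
    // the commit's p2p_enabled flag conveys.
    CHECK(cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, s0));
  }
  CHECK(cudaStreamSynchronize(s0));
  printf("copy done, p2p_capable=%d\n", (int)p2p_capable);

  CHECK(cudaFreeAsync(src, s0));
  CHECK(cudaSetDevice(1));
  CHECK(cudaFreeAsync(dst, s1));
  return 0;
}

On a peer-capable pair both branches work; the needs_MemcpyPeer split only matters when the devices cannot map each other's memory, which is exactly the case the commit's #if CUDA_VERSION > 11040 block handles.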