All graph tests in test_cuda.py pass except for test_graph_cudnn_dropout · pytorch/pytorch@c80a05f · GitHub

Commit c80a05f

All graph tests in test_cuda.py pass except for test_graph_cudnn_dropout
1 parent 368a0de commit c80a05f

3 files changed: +71 −24 lines changed

c10/cuda/CUDACachingAllocator.h

Lines changed: 8 additions & 2 deletions
@@ -141,7 +141,8 @@ C10_CUDA_API void notifyCaptureBegin(int device, CaptureId_t graph_id, MempoolId
 C10_CUDA_API void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id); \
 C10_CUDA_API void notifyCaptureEnded(int device, CaptureId_t graph_id); \
 C10_CUDA_API void notifyCaptureDestroy(int device, MempoolId_t mempool_id); \
-C10_CUDA_API std::mutex* getFreeMutex();
+C10_CUDA_API std::mutex* getFreeMutex(); \
+C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
 
 // Not meant to be called directly by clients.
 // Maybe make "CUDACachingAllocator" a class or struct, and make these private members?
@@ -303,7 +304,12 @@ inline std::mutex* getFreeMutex() {
 }
 
 // Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
-C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
+inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
+  static auto f = (std::strcmp(allocatorBackend(), "native") == 0) ?
+      THC::getIpcDevPtr : CudaMallocAsync::getIpcDevPtr;
+  return f(handle);
+
+}
 
 } // namespace CUDACachingAllocator
 } // namespace cuda
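
Note: the inline dispatch above latches the backend choice once per process (the function pointer lives in a `static` local). A minimal usage sketch from the Python side, assuming the PYTORCH_CUDA_ALLOC_CONF values ("backend:native" / "backend:cudaMallocAsync") that the test changes below key on, and that the setting is read before the first CUDA allocation:

    import os
    # Assumption: the allocator parses PYTORCH_CUDA_ALLOC_CONF before the first CUDA
    # allocation, so set it before importing torch (or at least before touching CUDA).
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

    import torch
    x = torch.ones(4, device="cuda")  # first allocation goes through the chosen backend
    # Because the dispatch is resolved once into a static function pointer,
    # changing the env var later in the same process has no effect.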

c10/cuda/CUDAMallocAsyncAllocator.cpp

Lines changed: 16 additions & 4 deletions
@@ -173,6 +173,8 @@ inline void lazy_init_device(int device) {
 
     pytorch_used_bytes[device] = 0;
     pytorch_memory_limits[device] = UINT64_MAX;
+
+    devs_initialized_flags[device] = true;
   }
 }
 
@@ -281,8 +283,9 @@ void free(void* ptr) {
       return;
     }
   } else if (C10_UNLIKELY(it->second.captured)) {
-    TORCH_WARN("Attempting uncaptured free of a captured allocation. "
-               "This is technically allowed, but may indicate you are losing "
+    TORCH_WARN("Attempting uncaptured free of a captured allocation with address ",
+               ptr,
+               "\nThis is technically allowed, but may indicate you are losing "
                "the last user-visible tensor through which the allocation can "
                "be accessed, so you'll have no way to view the data after "
                "future replays of the owning graph.");
@@ -323,7 +326,7 @@ void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {
   auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
   TORCH_INTERNAL_ASSERT(inserted.second,
                         "address returned by cudaMallocAsync already exists "
-                        "in usage_streams_each_ptr");
+                        "in ptr_info");
 
   inserted.first->second.usage_streams.emplace_back(stream, device);
 
@@ -371,7 +374,7 @@ void init(int dev_count) {
   TORCH_INTERNAL_ASSERT(!called, "init called twice");
   std::lock_guard<std::mutex> lk(general_mutex);
   device_count = dev_count;
-  devs_initialized_flags.resize(dev_count, 0);
+  devs_initialized_flags.resize(dev_count, false);
   dummy_unifying_free_streams.resize(dev_count);
   pytorch_used_bytes.resize(dev_count);
   pytorch_memory_limits.resize(dev_count);
@@ -474,6 +477,12 @@ std::mutex* getFreeMutex() {
   return &general_mutex;
 }
 
+std::shared_ptr<void> getIpcDevPtr(std::string handle) {
+  TORCH_CHECK(false,
+              "cudaMallocAsync does not yet support getIpcDevPtr. "
+              "If you need it, please file an issue describing your use case.");
+}
+
 // Collects stats for device.
 // If device hasn't been used yet, returns 0s without creating a context.
 DeviceStats getDeviceStats(int device) {
@@ -740,6 +749,9 @@ void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
 std::mutex* getFreeMutex() {
   NOT_AVAILABLE("getFreeMutex")
 }
+std::shared_ptr<void> getIpcDevPtr(std::string handle) {
+  NOT_AVAILABLE("getIpcDevPtr")
+}
 
 #endif
 
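
getIpcDevPtr is the hook behind sharing CUDA tensors across processes, so the new TORCH_CHECK means that path is rejected under the cudaMallocAsync backend for now. A hedged sketch of user code that would reach it, assuming the standard torch.multiprocessing behavior where the consumer process rebuilds the tensor from a CUDA IPC handle:

    import torch
    import torch.multiprocessing as mp

    def consumer(q):
        t = q.get()  # rebuilding the CUDA tensor here goes through getIpcDevPtr
        print(t.sum().item())

    if __name__ == "__main__":
        mp.set_start_method("spawn")  # required for CUDA with multiprocessing
        q = mp.Queue()
        p = mp.Process(target=consumer, args=(q,))
        p.start()
        q.put(torch.ones(4, device="cuda"))
        p.join()
    # Under PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync the consumer process
    # would hit the "does not yet support getIpcDevPtr" error added above.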

test/test_cuda.py

Lines changed: 47 additions & 18 deletions
@@ -3210,10 +3210,17 @@ def run(op, kwargs):
         for op, kwargs in ops_with_kwargs:
             run(op, kwargs)
 
+    def _using_cudaMallocAsync(self):
+        import os
+        alloc_conf = os.getenv("PYTORCH_CUDA_ALLOC_CONF")
+        return (alloc_conf is not None) and ("backend:cudaMallocAsync" in alloc_conf)
+
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_rng_distributions(self):
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
+
         size = 10000
         input = torch.rand((size,), device="cuda", dtype=torch.float)
         alloc = torch.empty((size,), device="cuda", dtype=torch.float)
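
The helper keys off the same PYTORCH_CUDA_ALLOC_CONF setting the allocator itself reads, so the graph tests can run under either backend without code changes. A convenience sketch (not part of the commit; assumes the backend:native / backend:cudaMallocAsync values and that test_cuda.py passes unittest's -k filter through) for sweeping the graph tests over both backends:

    import os
    import subprocess
    import sys

    for backend in ("native", "cudaMallocAsync"):
        env = dict(os.environ, PYTORCH_CUDA_ALLOC_CONF=f"backend:{backend}")
        # Per the commit message, test_graph_cudnn_dropout is still expected to fail,
        # so don't abort the sweep on a non-zero exit code.
        subprocess.run(
            [sys.executable, "test/test_cuda.py", "-v", "-k", "test_graph"],
            env=env,
            check=False,
        )
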
@@ -3280,11 +3287,19 @@ def run(module, op, args, kwargs):
             g.capture_end()
             torch.cuda.current_stream().wait_stream(stream)
 
-            try:
-                self.assertNotEqual(control1, t1)
-                self.assertNotEqual(control2, t2)
-            except Exception as e:
-                raise RuntimeError("Failed on " + module + "." + op) from e
+            if not using_cudaMallocAsync:
+                # Makes sure values haven't been populated yet
+                # (in other words, makes sure capture didn't actually run ops).
+                # We can only try this with the native allocator, for which captured
+                # addresses are already backed by cudaMalloced memory.
+                # If we try it with cudaMallocAsync, CUDA won't even consider
+                # the captured addresses allocated until replay(), and if we
+                # access them before replay() we get IMAs.
+                try:
+                    self.assertNotEqual(control1, t1)
+                    self.assertNotEqual(control2, t2)
+                except Exception as e:
+                    raise RuntimeError("Failed on " + module + "." + op) from e
 
             # Runs a dummy op prelude, as for controls, to make sure replay()
             # picks up the dummy op's state increment.
@@ -3319,6 +3334,7 @@ def run(module, op, args, kwargs):
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_two_successive(self):
         torch.cuda.empty_cache()
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
 
         size = 1000
         kSmallBuffer = 2097152
@@ -3366,24 +3382,28 @@ def func_with_temps(t, val):
             self.assertEqual(b.sum().item(), size * 3070)
             self.assertEqual(c.sum().item(), size * 442)
 
-            if share_mem != "Don't share":
-                self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
-                                 kSmallBuffer)
-            else:
-                reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]
+            if not using_cudaMallocAsync:
+                # These stat checks are specific to the native allocator.
+                if share_mem != "Don't share":
+                    self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
+                                     kSmallBuffer)
+                else:
+                    reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]
 
             del a, b, c, g0, g1
             # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
             torch.cuda.synchronize()
             torch.cuda.empty_cache()
 
-    @unittest.skip("Temporarily disabled due to a graphs bug in libcuda.so, " +
-                   "see https://github.com/pytorch/pytorch/pull/57556")
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+    @unittest.skipIf(int(torch.version.cuda.split(".")[1]) < 4,
+                     "Graph bindings disallow concurrent replay for CUDA < 11.4, see " +
+                     "https://github.com/pytorch/pytorch/pull/57556")
     def test_graph_concurrent_replay(self):
         torch.cuda.empty_cache()
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
 
         size = 1000000  # largeish to help expose race conditions
 
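
test_graph_two_successive loops over sharing modes, and the reserved_bytes check above only makes sense for the native caching allocator, whose per-graph pools show up in torch.cuda.memory_stats(). A minimal sketch of the pool-sharing pattern under test, assuming the pool()/capture_begin(pool=...) interface as it later shipped in torch.cuda.CUDAGraph:

    import torch

    s = torch.cuda.Stream()
    g0, g1 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    a = torch.ones(1000, device="cuda")

    with torch.cuda.stream(s):
        g0.capture_begin()
        b = a + 1
        g0.capture_end()

        # Second graph captures into the first graph's memory pool (the "share" case).
        g1.capture_begin(pool=g0.pool())
        c = b + 2
        g1.capture_end()
    torch.cuda.current_stream().wait_stream(s)

    g0.replay()
    g1.replay()
    # With the native allocator, sharing a pool keeps reserved bytes from growing by
    # another kSmallBuffer; cudaMallocAsync ignores the pool hint, so the stat check
    # is skipped there.
    print(torch.cuda.memory_stats()["reserved_bytes.all.current"])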

@@ -3432,12 +3452,16 @@ def func_with_temps(t, val):
             torch.cuda.current_stream().wait_stream(s0)
             torch.cuda.current_stream().wait_stream(s1)
 
-            if share_mem != "Don't share":
-                # Confirms concurrent replays using the same mempool corrupted each other.
+            if (not using_cudaMallocAsync) and (share_mem != "Don't share"):
+                # If we used the native allocator and shared mempools,
+                # we expect the concurrent replays corrupted each other.
                 self.assertNotEqual(b.sum().item(), size * 94)
                 self.assertNotEqual(c.sum().item(), size * 156)
             else:
-                # Confirms concurrent replays using different mempools did not corrupt each other.
+                # If we EITHER
+                # - used the native allocator without sharing mempools, OR
+                # - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe
+                # we don't expect memory corruption.
                 self.assertEqual(b.sum().item(), size * 94)
                 self.assertEqual(c.sum().item(), size * 156)
 
@@ -3451,6 +3475,7 @@ def func_with_temps(t, val):
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_three_successive(self):
         torch.cuda.empty_cache()
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
 
         size = 1000
 
@@ -3497,9 +3522,10 @@ def test_graph_three_successive(self):
             g2.replay()
             g1.replay()
 
-            # If share_mem is True, g2's capture should have reused c's memory for f. We replayed g2 then g1,
-            # so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
-            self.assertEqual(e.sum().item(), size * (7 + 3) if share_mem != "Don't share" else size * 5)
+            expect_corruption = (not using_cudaMallocAsync) and (share_mem != "Don't share")
+            # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
+            # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
+            self.assertEqual(e.sum().item(), size * (7 + 3) if expect_corruption else size * 5)
             self.assertEqual(f.sum().item(), size * 7)
 
             del a, b, d, e, f, g0, g1, g2
@@ -3511,6 +3537,9 @@ def test_graph_three_successive(self):
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
+        if self._using_cudaMallocAsync():
+            return
+
         kSmallSize = 1048576
         kSmallBuffer = 2097152
         kLargeBuffer = 20971520
