@@ -3210,10 +3210,17 @@ def run(op, kwargs):
         for op, kwargs in ops_with_kwargs:
             run(op, kwargs)
 
+    def _using_cudaMallocAsync(self):
+        import os
+        alloc_conf = os.getenv("PYTORCH_CUDA_ALLOC_CONF")
+        return (alloc_conf is not None) and ("backend:cudaMallocAsync" in alloc_conf)
+
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_rng_distributions(self):
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
+
         size = 10000
         input = torch.rand((size,), device="cuda", dtype=torch.float)
         alloc = torch.empty((size,), device="cuda", dtype=torch.float)
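
For context, a minimal sketch of how the _using_cudaMallocAsync helper above would see the backend selection. It assumes a build and driver that support the cudaMallocAsync backend (CUDA >= 11.4) and that PYTORCH_CUDA_ALLOC_CONF is set before the first CUDA allocation, since the allocator reads it at initialization:

# Sketch: select the cudaMallocAsync backend for a test run.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"  # must precede CUDA init

import torch  # noqa: E402

torch.ones(1, device="cuda")  # first allocation now goes through cudaMallocAsync
# This is exactly the condition _using_cudaMallocAsync() checks:
assert "backend:cudaMallocAsync" in os.getenv("PYTORCH_CUDA_ALLOC_CONF", "")
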
@@ -3280,11 +3287,19 @@ def run(module, op, args, kwargs):
             g.capture_end()
             torch.cuda.current_stream().wait_stream(stream)
 
-            try:
-                self.assertNotEqual(control1, t1)
-                self.assertNotEqual(control2, t2)
-            except Exception as e:
-                raise RuntimeError("Failed on " + module + "." + op) from e
+            if not using_cudaMallocAsync:
+                # Makes sure values haven't been populated yet
+                # (in other words, makes sure capture didn't actually run ops).
+                # We can only try this with the native allocator, for which captured
+                # addresses are already backed by cudaMalloced memory.
+                # If we try it with cudaMallocAsync, CUDA won't even consider
+                # the captured addresses allocated until replay(), and if we
+                # access them before replay() we get IMAs.
+                try:
+                    self.assertNotEqual(control1, t1)
+                    self.assertNotEqual(control2, t2)
+                except Exception as e:
+                    raise RuntimeError("Failed on " + module + "." + op) from e
 
             # Runs a dummy op prelude, as for controls, to make sure replay()
             # picks up the dummy op's state increment.
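
To illustrate the reasoning in the comment above, here is a minimal capture/replay sketch. It assumes the public torch.cuda.CUDAGraph binding (capture_begin/capture_end/replay); the binding this test actually uses may differ. Ops issued during capture are recorded rather than executed, so captured outputs hold nothing meaningful until replay():

import torch

a = torch.full((10,), 3.0, device="cuda")
g = torch.cuda.CUDAGraph()

# Capture must run on a non-default stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    g.capture_begin()
    b = a * 2           # recorded into the graph, not run yet
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

# Native allocator: b's storage is already cudaMalloc'ed but holds unwritten garbage.
# cudaMallocAsync: b's address isn't considered allocated until replay(), so touching
# it here would be an illegal memory access (the "IMAs" mentioned above).
g.replay()              # now the captured work actually runs
assert b.sum().item() == 60.0
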
@@ -3319,6 +3334,7 @@ def run(module, op, args, kwargs):
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_two_successive(self):
         torch.cuda.empty_cache()
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
 
         size = 1000
         kSmallBuffer = 2097152
@@ -3366,24 +3382,28 @@ def func_with_temps(t, val):
             self.assertEqual(b.sum().item(), size * 3070)
             self.assertEqual(c.sum().item(), size * 442)
 
-            if share_mem != "Don't share":
-                self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
-                                 kSmallBuffer)
-            else:
-                reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]
+            if not using_cudaMallocAsync:
+                # These stat checks are specific to the native allocator.
+                if share_mem != "Don't share":
+                    self.assertEqual(reserved_no_sharing - torch.cuda.memory_stats()["reserved_bytes.all.current"],
+                                     kSmallBuffer)
+                else:
+                    reserved_no_sharing = torch.cuda.memory_stats()["reserved_bytes.all.current"]
 
             del a, b, c, g0, g1
             # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
             torch.cuda.synchronize()
             torch.cuda.empty_cache()
 
-    @unittest.skip("Temporarily disabled due to a graphs bug in libcuda.so, " +
-                   "see https://github.com/pytorch/pytorch/pull/57556")
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+    @unittest.skipIf(int(torch.version.cuda.split(".")[1]) < 4,
+                     "Graph bindings disallow concurrent replay for CUDA < 11.4, see " +
+                     "https://github.com/pytorch/pytorch/pull/57556")
     def test_graph_concurrent_replay(self):
         torch.cuda.empty_cache()
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
 
         size = 1000000  # largeish to help expose race conditions
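
As background for the stat checks gated above, a small sketch reading the same reserved_bytes counter the assertions compare. It assumes the native caching allocator, whose small-pool blocks are the kSmallBuffer = 2 MiB units referenced in this test:

import torch

def reserved_bytes():
    # Same key the assertions above compare.
    return torch.cuda.memory_stats()["reserved_bytes.all.current"]

before = reserved_bytes()
tmp = torch.empty(1000, device="cuda")   # small request, typically served from a 2 MiB small-pool block
print("reserved bytes grew by", reserved_bytes() - before)  # 0 if a cached block already covered it

del tmp
torch.cuda.empty_cache()                 # returns cached blocks to CUDA, shrinking the counter
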
@@ -3432,12 +3452,16 @@ def func_with_temps(t, val):
             torch.cuda.current_stream().wait_stream(s0)
             torch.cuda.current_stream().wait_stream(s1)
 
-            if share_mem != "Don't share":
-                # Confirms concurrent replays using the same mempool corrupted each other.
+            if (not using_cudaMallocAsync) and (share_mem != "Don't share"):
+                # If we used the native allocator and shared mempools,
+                # we expect the concurrent replays corrupted each other.
                 self.assertNotEqual(b.sum().item(), size * 94)
                 self.assertNotEqual(c.sum().item(), size * 156)
             else:
-                # Confirms concurrent replays using different mempools did not corrupt each other.
+                # If we EITHER
+                #   - used the native allocator without sharing mempools, OR
+                #   - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe,
+                # we don't expect memory corruption.
                 self.assertEqual(b.sum().item(), size * 94)
                 self.assertEqual(c.sum().item(), size * 156)
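
The pool-sharing hints mentioned in the comment above are the opaque handles passed at capture time. A minimal sketch follows, assuming the public torch.cuda.graph_pool_handle() and CUDAGraph.capture_begin(pool=...) API, which may differ from the bindings this test exercises. With the native allocator the second capture can reuse the first graph's mempool; per the comment above, cudaMallocAsync ignores the hint:

import torch

# Two graphs capturing into one shared memory pool.
pool = torch.cuda.graph_pool_handle()
g0, g1 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    g0.capture_begin(pool=pool)
    a = torch.ones(1000, device="cuda") * 2
    g0.capture_end()

    g1.capture_begin(pool=pool)   # same pool: g1 may reuse blocks g0's capture freed
    b = torch.ones(1000, device="cuda") * 3
    g1.capture_end()
torch.cuda.current_stream().wait_stream(s)

# Sequential replay is safe even with a shared pool; the test above replays on two
# streams at once to check whether sharing lets the replays stomp on each other.
g0.replay()
g1.replay()
torch.cuda.synchronize()
assert a.sum().item() == 2000 and b.sum().item() == 3000
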
@@ -3451,6 +3475,7 @@ def func_with_temps(t, val):
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_three_successive(self):
         torch.cuda.empty_cache()
+        using_cudaMallocAsync = self._using_cudaMallocAsync()
 
         size = 1000
@@ -3497,9 +3522,10 @@ def test_graph_three_successive(self):
             g2.replay()
             g1.replay()
 
-            # If share_mem is True, g2's capture should have reused c's memory for f. We replayed g2 then g1,
-            # so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
-            self.assertEqual(e.sum().item(), size * (7 + 3) if share_mem != "Don't share" else size * 5)
+            expect_corruption = (not using_cudaMallocAsync) and (share_mem != "Don't share")
+            # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
+            # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
+            self.assertEqual(e.sum().item(), size * (7 + 3) if expect_corruption else size * 5)
             self.assertEqual(f.sum().item(), size * 7)
 
             del a, b, d, e, f, g0, g1, g2
@@ -3511,6 +3537,9 @@ def test_graph_three_successive(self):
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
+        if self._using_cudaMallocAsync():
+            return
+
         kSmallSize = 1048576
         kSmallBuffer = 2097152
         kLargeBuffer = 20971520