[symm_mem] Fix nccl test for symm mem (#156752) · pytorch/pytorch@4585c33
Commit 4585c33

fduwjj authored and pytorchmergebot committed

[symm_mem] Fix nccl test for symm mem (#156752)

Try not to call set_device.

Fixes #156569
Pull Request resolved: #156752
Approved by: https://github.com/kwen2501

1 parent 7521cd9 commit 4585c33

File tree

4 files changed: +32, -38 lines


.ci/pytorch/test.sh

Lines changed: 2 additions & 1 deletion

@@ -329,7 +329,8 @@ test_h100_distributed() {
   time python test/run_test.py --include distributed/_composable/fsdp/test_fully_shard_comm.py -k TestFullyShardAllocFromPG $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   # symmetric memory test
   time python test/run_test.py --include distributed/test_symmetric_memory.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
-  time python test/run_test.py --include distributed/test_nvshmem.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time TORCH_SYMMMEM=NVSHMEM python test/run_test.py --include distributed/test_nvshmem.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time TORCH_SYMMMEM=NCCL python test/run_test.py --include distributed/test_nccl.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   assert_git_not_dirty
 }

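For context on the two new lines above: the CI step selects a symmetric-memory backend purely through the TORCH_SYMMMEM environment variable before invoking run_test.py. A minimal sketch, assuming one wanted an equivalent guard inside a test file itself (the helper and test class below are illustrative, not part of this PR):

# Illustrative only: skip a test unless the intended symmetric-memory backend was
# requested via TORCH_SYMMMEM, mirroring the env-var selection used in test.sh above.
import os
import unittest


def requires_symm_mem_backend(expected):  # hypothetical helper, not in the PR
    selected = os.environ.get("TORCH_SYMMMEM", "")
    return unittest.skipUnless(
        selected == expected,
        f"requires TORCH_SYMMMEM={expected}, got {selected!r}",
    )


class NCCLBackendOnlyTest(unittest.TestCase):  # hypothetical test class
    @requires_symm_mem_backend("NCCL")
    def test_runs_only_under_nccl_backend(self):
        self.assertEqual(os.environ["TORCH_SYMMMEM"], "NCCL")


if __name__ == "__main__":
    unittest.main()
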
test/distributed/test_nccl.py

Lines changed: 4 additions & 10 deletions

@@ -18,6 +18,7 @@
     IS_WINDOWS,
     load_tests,
     NoTest,
+    requires_cuda_p2p_access,
     run_tests,
     skip_but_pass_in_sandcastle_if,
     TEST_WITH_ROCM,
@@ -241,24 +242,17 @@ def test_reduce_scatter(self, device, dtype):
             self.assertEqual(outputs[i], expected[i])
 
 
-device_type = "cuda"
-device_module = torch.get_device_module(device_type)
-
-
+@requires_cuda_p2p_access()
 class NCCLSymmetricMemoryTest(MultiProcContinousTest):
-    def _init_device(self) -> None:
-        # TODO: relieve this (seems to hang if without)
-        device_module.set_device(self.device)
-
     @property
     def device(self) -> torch.device:
-        return torch.device(device_type, self.rank)
+        return torch.device("cuda", self.rank)
 
     # To run this test, one needs to TORCH_SYMMMEM=NCCL when running the test.
     @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
     @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
     def test_nccl_symmem_alloc(self):
-        self._init_device()
+        torch.cuda.set_device(self.rank)
         c10d.all_reduce(torch.ones(1, device=self.device))
         group_name = c10d.group.WORLD.group_name
         symm_mem.enable_symm_mem_for_group(group_name)

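The change above drops the module-level device_module indirection and the _init_device helper in favor of a direct torch.cuda.set_device(self.rank) call inside the test body. A minimal standalone sketch of that per-rank setup order, assuming a torchrun-style launcher provides RANK and WORLD_SIZE and that symm_mem aliases torch.distributed._symmetric_memory (this script is illustrative, not the MultiProcContinousTest harness):

# Illustrative sketch: bind each rank to its CUDA device before any NCCL collective,
# then enable symmetric memory for the default group, mirroring test_nccl_symmem_alloc.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # assumed symm_mem alias


def main():
    rank = int(os.environ["RANK"])              # set by the launcher (assumption)
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank)                 # replaces the old _init_device helper
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Warm-up collective so the NCCL communicator is created on the bound device.
    dist.all_reduce(torch.ones(1, device=torch.device("cuda", rank)))

    # Enable symmetric memory for the default group, as the test does.
    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()  # e.g. TORCH_SYMMMEM=NCCL torchrun --nproc-per-node=2 <this file>
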
test/distributed/test_symmetric_memory.py

Lines changed: 1 addition & 22 deletions

@@ -34,9 +34,9 @@
     MI300_ARCH,
     parametrize,
     requires_cuda,
+    requires_cuda_p2p_access,
     run_tests,
     runOnRocmArch,
-    skip_but_pass_in_sandcastle_if,
     skipIfRocm,
     TEST_WITH_ROCM,
     TestCase,
@@ -50,27 +50,6 @@
 device_module = torch.get_device_module(device_type)
 
 
-def requires_cuda_p2p_access():
-    cuda_p2p_access_available = (
-        torch.cuda.is_available()
-        and torch.cuda.get_device_capability() >= (8, 0)
-        and torch.cuda.device_count() >= 2
-    )
-    num_devices = torch.cuda.device_count()
-    for i in range(num_devices - 1):
-        for j in range(i + 1, num_devices):
-            if not torch.cuda.can_device_access_peer(i, j):
-                cuda_p2p_access_available = False
-                break
-        if not cuda_p2p_access_available:
-            break
-
-    return skip_but_pass_in_sandcastle_if(
-        not cuda_p2p_access_available,
-        "cuda p2p access is not available",
-    )
-
-
 @instantiate_parametrized_tests
 @requires_cuda_p2p_access()
 class SymmetricMemoryTest(MultiProcContinousTest):

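With the local helper removed here, requires_cuda_p2p_access is imported from torch.testing._internal.common_utils (see the final file below), so other distributed tests can share it. A minimal sketch of reusing the shared decorator (the test class is hypothetical):

# Illustrative only: reuse the decorator that now lives in common_utils instead of
# redefining it per test file, as test_symmetric_memory.py did before this change.
import torch
from torch.testing._internal.common_utils import (
    requires_cuda_p2p_access,
    run_tests,
    TestCase,
)


@requires_cuda_p2p_access()  # skips (passes in Sandcastle) without >=2 P2P-capable SM80+ GPUs
class PeerAccessSmokeTest(TestCase):  # hypothetical class, not part of the PR
    def test_peer_access_prerequisites(self):
        # Only runs when the decorator's checks passed at decoration time.
        self.assertTrue(torch.cuda.is_available())
        self.assertGreaterEqual(torch.cuda.device_count(), 2)


if __name__ == "__main__":
    run_tests()
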
torch/testing/_internal/common_utils.py

Lines changed: 25 additions & 5 deletions

@@ -2037,6 +2037,26 @@ def wrapper(*args, **kwargs):
         return dec_fn(func)
     return dec_fn
 
+def requires_cuda_p2p_access():
+    cuda_p2p_access_available = (
+        torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (8, 0)
+        and torch.cuda.device_count() >= 2
+    )
+    num_devices = torch.cuda.device_count()
+    for i in range(num_devices - 1):
+        for j in range(i + 1, num_devices):
+            if not torch.cuda.can_device_access_peer(i, j):
+                cuda_p2p_access_available = False
+                break
+        if not cuda_p2p_access_available:
+            break
+
+    return skip_but_pass_in_sandcastle_if(
+        not cuda_p2p_access_available,
+        "cuda p2p access is not available",
+    )
+
 # Reverts the linalg backend back to default to make sure potential failures in one
 # test do not affect other tests
 def setLinalgBackendsToDefaultFinally(fn):
@@ -2551,18 +2571,18 @@ def __exit__(self, exc_type, exc_value, traceback):
                 msg = ("CUDA caching allocator reports a memory leak not "  # type: ignore[possibly-undefined]
                        f"verified by the driver API in {self.name}! "
                        f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
-                       f"and is now reported as {caching_allocator_mem_allocated} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "  # type: ignore[possibly-undefined]
                        f"on device {i}. "
-                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")  # type: ignore[possibly-undefined]
                 warnings.warn(msg)
-            elif caching_allocator_discrepancy and driver_discrepancy:
+            elif caching_allocator_discrepancy and driver_discrepancy:  # type: ignore[possibly-undefined]
                 # A caching allocator discrepancy validated by the driver API is a
                 # failure (except on ROCm, see below)
                 msg = (f"CUDA driver API confirmed a leak in {self.name}! "  # type: ignore[possibly-undefined]
                        f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
-                       f"and is now reported as {caching_allocator_mem_allocated} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "  # type: ignore[possibly-undefined]
                        f"on device {i}. "
-                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")  # type: ignore[possibly-undefined]
 
                 raise RuntimeError(msg)
 

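The second common_utils.py hunk only extends the existing # type: ignore[possibly-undefined] suppressions to the remaining lines that reference conditionally-bound names. A minimal illustration of why a checker flags each referencing line separately, assuming mypy with the possibly-undefined error code enabled (the names below are made up):

# Illustrative only: value is bound inside a conditional branch, so a checker run
# with the possibly-undefined error code enabled flags every physical line that
# reads it, which is why each continuation line in the diff needs its own ignore.
import random


def report() -> str:
    if random.random() < 0.5:
        value = 42
    # Calling this may also raise NameError at runtime, which is exactly what the
    # static warning is about.
    return (
        "prefix "
        f"uses {value} once "   # flagged on this line
        f"and {value} again"    # flagged separately on this line
    )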