Revert "[Distributed][CI] Rework continuous TestCase (#153653)" · pytorch/pytorch@674a85c · GitHub

Commit 674a85c

Revert "[Distributed][CI] Rework continuous TestCase (#153653)"
This reverts commit 0d5c628. Reverted #153653 on behalf of https://github.com/kwen2501 because more fixes are needed (see the comment on #153653).
1 parent 0d5c628 · commit 674a85c

File tree

4 files changed, +120 −200 lines

test/distributed/_test_template.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

test/distributed/test_c10d_ops_nccl.py

Lines changed: 23 additions & 2 deletions
@@ -11,6 +11,7 @@
 import math
 import os
 import sys
+import tempfile
 
 import torch
 import torch.distributed as c10d
@@ -29,9 +30,9 @@
     requires_nccl,
     requires_nccl_version,
     sm_is_or_higher_than,
+    TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
-    run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
     TEST_WITH_DEV_DBG_ASAN,
@@ -1043,4 +1044,24 @@ def allgather_base(output_t, input_t):
 
 
 if __name__ == "__main__":
-    run_tests()
+    if not torch.cuda.is_available():
+        sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+
+    rank = int(os.getenv("RANK", -1))
+    world_size = int(os.getenv("WORLD_SIZE", -1))
+
+    if world_size == -1:  # Not set by external launcher
+        world_size = torch.cuda.device_count()
+
+    if rank != -1:
+        # Launched with torchrun or other multi-proc launchers. Directly run the test.
+        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
+    else:
+        # Launched as a single process. Spawn subprocess to run the tests.
+        # Also need a rendezvous file for `init_process_group` purpose.
+        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
+        torch.multiprocessing.spawn(
+            ProcessGroupNCCLOpTest.run_rank,
+            nprocs=world_size,
+            args=(world_size, rdvz_file),
+        )
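The restored entry point supports two launch modes: if an external launcher such as torchrun has set RANK and WORLD_SIZE, the process runs its rank directly via ProcessGroupNCCLOpTest.run_rank; otherwise the file spawns one worker per visible GPU and passes a temporary rendezvous file for init_process_group. The snippet below is a minimal, self-contained sketch of that dual-mode pattern. The worker function _demo_worker is hypothetical and stands in for run_rank, and it uses a gloo process group so the sketch can run on CPU; it illustrates the launch logic only and is not part of the PyTorch test suite.

import os
import tempfile
from typing import Optional

import torch
import torch.distributed as dist


def _demo_worker(rank: int, world_size: int, rdvz_file: Optional[str] = None) -> None:
    # When spawned locally, rendezvous through a shared file; under torchrun,
    # fall back to env:// (MASTER_ADDR/MASTER_PORT come from the launcher).
    init_method = f"file://{rdvz_file}" if rdvz_file else "env://"
    dist.init_process_group(
        backend="gloo", init_method=init_method, rank=rank, world_size=world_size
    )
    t = torch.ones(1) * rank
    dist.all_reduce(t)  # sums the ranks across the group, just to exercise it
    dist.destroy_process_group()


if __name__ == "__main__":
    rank = int(os.getenv("RANK", "-1"))
    world_size = int(os.getenv("WORLD_SIZE", "-1"))

    if rank != -1:
        # Launched by torchrun (or another external launcher): run this rank directly.
        _demo_worker(rank, world_size)
    else:
        # Single-process launch: spawn local workers and share a rendezvous file.
        world_size = max(torch.cuda.device_count(), 2)  # at least 2 ranks on CPU-only hosts
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            _demo_worker, nprocs=world_size, args=(world_size, rdvz_file)
        )

With a file like this saved as, say, demo.py, either launch mode would work: python demo.py spawns the workers itself, while torchrun --nproc_per_node=2 demo.py supplies RANK and WORLD_SIZE and takes the direct path.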

test/distributed/test_nvshmem.py

Lines changed: 37 additions & 8 deletions
@@ -7,13 +7,16 @@
 
 import os
 import sys
+import tempfile
 
 import torch
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem
-from torch.testing._internal.common_distributed import MultiProcContinousTest
+from torch.testing._internal.common_distributed import (
+    MultiProcContinousTest,
+    TEST_SKIPS,
+)
 from torch.testing._internal.common_utils import (
-    run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
 )
@@ -44,20 +47,28 @@ def requires_nvshmem():
 
 @requires_nvshmem()
 class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
-    def _init_device(self) -> None:
+    def setUp(self) -> None:
+        super().setUp()
         # TODO: relieve this (seems to hang if without)
         device_module.set_device(self.device)
         # NOTE: required for nvshmem allocation
         torch.empty(1, device=self.device)
 
+    # Required by MultiProcContinousTest
+    @classmethod
+    def backend_str(cls) -> str:
+        return "nccl"
+
+    @property
+    def world_size(self) -> int:
+        return device_module.device_count()
+
     @property
     def device(self) -> torch.device:
         return torch.device(device_type, self.rank)
 
     @skipIfRocm
     def test_nvshmem_all_to_all(self) -> None:
-        self._init_device()
-
         group_name = dist.group.WORLD.group_name
         symm_mem.enable_symm_mem_for_group(group_name)
 
@@ -81,8 +92,6 @@ def test_nvshmem_all_to_all(self) -> None:
 
     @skipIfRocm
     def test_nvshmem_all_to_all_vdev(self) -> None:
-        self._init_device()
-
         group_name = dist.group.WORLD.group_name
         symm_mem.enable_symm_mem_for_group(group_name)
 
@@ -130,4 +139,24 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
 
 
 if __name__ == "__main__":
-    run_tests()
+    if not device_module.is_available():
+        sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+
+    # If launched by torchrun, these values would have been set
+    rank = int(os.getenv("RANK", "-1"))
+    world_size = int(os.getenv("WORLD_SIZE", "-1"))
+
+    if rank != -1:
+        # Launched with torchrun or other multi-proc launchers. Directly run the test.
+        NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
+    else:
+        # No external launcher, spawn N processes
+        world_size = device_module.device_count()
+        # Launched as a single process. Spawn subprocess to run the tests.
+        # Also need a rendezvous file for `init_process_group` purpose.
+        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
+        torch.multiprocessing.spawn(
+            NVSHMEMSymmetricMemoryTest.run_rank,
+            nprocs=world_size,
+            args=(world_size, rdvz_file),
+        )
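This revert also restores the overrides that the MultiProcContinousTest pattern expects from its subclasses: a backend_str() classmethod naming the process-group backend, a world_size property, and a setUp() that prepares the device once per rank, with run_rank serving as the per-process entry point. As a rough mental model only, here is a hypothetical sketch of how a run_rank-style classmethod could be wired on top of unittest, initializing the process group once and then running every test of the class in the same process; it is an illustration under assumptions, not the actual MultiProcContinousTest implementation in torch.testing._internal.common_distributed.

import unittest
from typing import Optional

import torch.distributed as dist


class _ContinuousTestSketch(unittest.TestCase):
    # Filled in by run_rank before any test executes.
    rank: int = 0
    world_size: int = 1

    @classmethod
    def backend_str(cls) -> str:
        # The nvshmem test above returns "nccl" here; gloo keeps the sketch CPU-only.
        return "gloo"

    def test_group_is_ready(self) -> None:
        # Every rank sees the same initialized group.
        self.assertEqual(dist.get_world_size(), self.world_size)

    @classmethod
    def run_rank(cls, rank: int, world_size: int, rdvz_file: Optional[str] = None) -> None:
        # Initialize the process group once, then run every test method of the
        # class in this same process ("continuous"), reusing the group throughout.
        cls.rank, cls.world_size = rank, world_size
        init_method = f"file://{rdvz_file}" if rdvz_file else "env://"
        dist.init_process_group(
            backend=cls.backend_str(),
            init_method=init_method,
            rank=rank,
            world_size=world_size,
        )
        try:
            suite = unittest.defaultTestLoader.loadTestsFromTestCase(cls)
            unittest.TextTestRunner(verbosity=2).run(suite)
        finally:
            dist.destroy_process_group()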
