[Distributed][CI] Rework continuous TestCase by kwen2501 · Pull Request #153653 · pytorch/pytorch · GitHub

Closed

Changes from 1 commit

3 changes: 3 additions & 0 deletions .lintrunner.toml
@@ -990,6 +990,9 @@ exclude_patterns = [
     'test/distributed/optim/test_apply_optimizer_in_backward.py',
     'test/distributed/optim/test_named_optimizer.py',
     'test/distributed/test_c10d_spawn.py',
+    'test/distributed/test_c10d_ops_nccl.py',
+    'test/distributed/test_symmetric_memory.py',
+    'test/distributed/test_nvshmem.py',
     'test/distributed/test_collective_utils.py',
     'test/distributions/test_distributions.py',
     'test/inductor/test_aot_inductor_utils.py',
39 changes: 19 additions & 20 deletions test/distributed/test_c10d_ops_nccl.py
@@ -1045,23 +1045,22 @@ def allgather_base(output_t, input_t):
 
 if __name__ == "__main__":
     if not torch.cuda.is_available():
-        sys.exit(TEST_SKIPS["no_cuda"].exit_code)
-
-    rank = int(os.getenv("RANK", -1))
-    world_size = int(os.getenv("WORLD_SIZE", -1))
-
-    if world_size == -1:  # Not set by external launcher
-        world_size = torch.cuda.device_count()
-
-    if rank != -1:
-        # Launched with torchrun or other multi-proc launchers. Directly run the test.
-        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
-    else:
-        # Launched as a single process. Spawn subprocess to run the tests.
-        # Also need a rendezvous file for `init_process_group` purpose.
-        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
-        torch.multiprocessing.spawn(
-            ProcessGroupNCCLOpTest.run_rank,
-            nprocs=world_size,
-            args=(world_size, rdvz_file),
-        )
+        os._exit(TEST_SKIPS["no_cuda"].exit_code)
+
+    # Use device count as world size
+    world_size = torch.cuda.device_count()
+    # Also need a rendezvous file for `init_process_group` purpose.
+    rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
+    # Spawn subprocess to run the tests.
+    # `run_tests()` will be called under `run_rank`
+    torch.multiprocessing.spawn(
+        MultiProcContinousTest.run_rank,  # entry point
+        nprocs=world_size,
+        args=(world_size, rdvz_file),
+    )
+
+    # Clear up the rendezvous file
+    try:
+        os.remove(rdvz_file)
+    except OSError:
+        pass
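
For context, the spawn call above hands each worker off to `MultiProcContinousTest.run_rank`, which lives in `torch/testing/_internal/common_distributed.py` and is not shown in this diff. The sketch below is a minimal illustration of what such an entry point could look like, assuming the signature implied by `torch.multiprocessing.spawn` (which prepends the process index to `args`, so the callable receives `(rank, world_size, rdvz_file)`) and a file-based rendezvous for `init_process_group`. It is not the PR's actual implementation.

# Hypothetical sketch only; the real class is defined in
# torch/testing/_internal/common_distributed.py and may differ.
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests


class MultiProcContinousTest:
    @classmethod
    def run_rank(cls, rank, world_size, rdvz_file):
        # Rendezvous over the temp file created by the parent process,
        # so every spawned worker joins the same process group.
        dist.init_process_group(
            backend="nccl",
            init_method=f"file://{rdvz_file}",
            rank=rank,
            world_size=world_size,
        )
        try:
            # Run the whole unittest suite in this worker; individual tests
            # reuse the already-initialized process group ("continuous" tests).
            run_tests()
        finally:
            dist.destroy_process_group()

Under this assumption, `python test/distributed/test_c10d_ops_nccl.py` spawns one such worker per visible GPU, while the deleted branch above was the path that instead read `RANK`/`WORLD_SIZE` set by an external launcher such as torchrun and called `run_rank` directly.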