[Distributed][CI] Rework continuous TestCase by kwen2501 · Pull Request #153653 · pytorch/pytorch · GitHub
Changes from 3 commits
25 changes: 2 additions & 23 deletions test/distributed/test_c10d_ops_nccl.py
@@ -11,7 +11,6 @@
import math
import os
import sys
import tempfile

import torch
import torch.distributed as c10d
@@ -30,9 +29,9 @@
requires_nccl,
requires_nccl_version,
sm_is_or_higher_than,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
@@ -1044,24 +1043,4 @@ def allgather_base(output_t, input_t):


if __name__ == "__main__":
if not torch.cuda.is_available():
sys.exit(TEST_SKIPS["no_cuda"].exit_code)

rank = int(os.getenv("RANK", -1))
world_size = int(os.getenv("WORLD_SIZE", -1))

if world_size == -1: # Not set by external launcher
world_size = torch.cuda.device_count()

if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
ProcessGroupNCCLOpTest.run_rank(rank, world_size)
else:
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
ProcessGroupNCCLOpTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()
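
With this change the per-file launch boilerplate is gone: the RANK/WORLD_SIZE parsing, rendezvous-file creation, and torch.multiprocessing.spawn fallback removed above are expected to be handled by the reworked MultiProcContinousTest / run_tests machinery, so the module entrypoint shrinks to a single run_tests() call. A minimal sketch of a test module under this pattern (class and test names are hypothetical; it assumes the harness initializes the default process group and exposes self.rank / self.world_size, as the tests in this file do):

# Hypothetical module following the reworked pattern; not part of this PR.
import torch
import torch.distributed as dist

from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests


class ExampleOpsTest(MultiProcContinousTest):
    def test_all_reduce_sum(self):
        # self.rank / self.world_size and the default process group are assumed
        # to be provided by the shared harness.
        device = torch.device("cuda", self.rank)
        t = torch.ones(1, device=device)
        dist.all_reduce(t)
        self.assertEqual(t.item(), float(self.world_size))


if __name__ == "__main__":
    # No manual rank detection, rendezvous file, or spawn call in the test file.
    run_tests()
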
45 changes: 8 additions & 37 deletions test/distributed/test_nvshmem.py
@@ -7,16 +7,13 @@

import os
import sys
import tempfile

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
TEST_SKIPS,
)
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
)
@@ -47,28 +44,20 @@ def requires_nvshmem():

@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
def setUp(self) -> None:
super().setUp()
def _init_device(self) -> None:
# TODO: relieve this (seems to hang if without)
device_module.set_device(self.device)
# NOTE: required for nvshmem allocation
torch.empty(1, device=self.device)

# Required by MultiProcContinousTest
@classmethod
def backend_str(cls) -> str:
return "nccl"

@property
def world_size(self) -> int:
return device_module.device_count()

@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)

@skipIfRocm
def test_nvshmem_all_to_all(self) -> None:
self._init_device()

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

@@ -92,6 +81,8 @@ def test_nvshmem_all_to_all(self) -> None:

@skipIfRocm
def test_nvshmem_all_to_all_vdev(self) -> None:
self._init_device()

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

@@ -139,24 +130,4 @@ def test_nvshmem_all_to_all_vdev(self) -> None:


if __name__ == "__main__":
if not device_module.is_available():
sys.exit(TEST_SKIPS["no_cuda"].exit_code)

# If launched by torchrun, these values would have been set
rank = int(os.getenv("RANK", "-1"))
world_size = int(os.getenv("WORLD_SIZE", "-1"))

if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
else:
# No external launcher, spawn N processes
world_size = device_module.device_count()
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
NVSHMEMSymmetricMemoryTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()
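
Besides the same entrypoint cleanup, test_nvshmem.py moves device setup out of setUp() into an explicit _init_device() helper that each test calls first, and drops the backend_str / world_size overrides that the base class presumably no longer requires. A self-contained sketch of the resulting test shape (class and test names are hypothetical; it uses CUDA directly where the real file goes through its device_type / device_module globals, and it assumes the harness sets up the default process group):

# Hypothetical sketch of a post-rework symmetric-memory test; not part of this PR.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests


class ExampleSymmMemTest(MultiProcContinousTest):
    @property
    def device(self) -> torch.device:
        return torch.device("cuda", self.rank)

    def _init_device(self) -> None:
        # Bind this rank to its device and allocate once so the CUDA context
        # exists before NVSHMEM-backed allocations are made.
        torch.cuda.set_device(self.device)
        torch.empty(1, device=self.device)

    def test_enable_symm_mem(self) -> None:
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)


if __name__ == "__main__":
    run_tests()
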
16 changes: 16 additions & 0 deletions test/distributed/test_template.py
@@ -0,0 +1,16 @@
# Owner(s): ["oncall: distributed"]

from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests


class TestTemplate(MultiProcContinousTest):
def testABC(self):
print(f"rank {self.rank} of {self.world_size} testing ABC")

def testDEF(self):
print(f"rank {self.rank} of {self.world_size} testing DEF")


if __name__ == "__main__":
run_tests()
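
The new test_template.py captures the minimum a distributed test file now needs: subclass MultiProcContinousTest, write test methods that use self.rank and self.world_size, and call run_tests() under __main__. The rank detection, device counting, rendezvous-file handling, and spawning that the removed __main__ blocks above performed per file are no longer written in each test module and are expected to come from the shared harness instead.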