 
 import os
 import sys
+import tempfile
 
 import torch
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem
-from torch.testing._internal.common_distributed import MultiProcContinousTest
+from torch.testing._internal.common_distributed import (
+    MultiProcContinousTest,
+    TEST_SKIPS,
+)
 from torch.testing._internal.common_utils import (
-    run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
 )
@@ -44,20 +47,28 @@ def requires_nvshmem():
 
 @requires_nvshmem()
 class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
-    def _init_device(self) -> None:
+    def setUp(self) -> None:
+        super().setUp()
         # TODO: relieve this (seems to hang if without)
         device_module.set_device(self.device)
         # NOTE: required for nvshmem allocation
         torch.empty(1, device=self.device)
 
+    # Required by MultiProcContinousTest
+    @classmethod
+    def backend_str(cls) -> str:
+        return "nccl"
+
+    @property
+    def world_size(self) -> int:
+        return device_module.device_count()
+
     @property
     def device(self) -> torch.device:
         return torch.device(device_type, self.rank)
 
     @skipIfRocm
     def test_nvshmem_all_to_all(self) -> None:
-        self._init_device()
-
         group_name = dist.group.WORLD.group_name
         symm_mem.enable_symm_mem_for_group(group_name)
 
@@ -81,8 +92,6 @@ def test_nvshmem_all_to_all(self) -> None:
 
     @skipIfRocm
     def test_nvshmem_all_to_all_vdev(self) -> None:
-        self._init_device()
-
         group_name = dist.group.WORLD.group_name
         symm_mem.enable_symm_mem_for_group(group_name)
 
@@ -130,4 +139,24 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
 
 
 if __name__ == "__main__":
-    run_tests()
+    if not device_module.is_available():
+        sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+
+    # If launched by torchrun, these values will have been set
+    rank = int(os.getenv("RANK", "-1"))
+    world_size = int(os.getenv("WORLD_SIZE", "-1"))
+
+    if rank != -1:
+        # Launched by torchrun or another multi-proc launcher; run the test directly.
+        NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
+    else:
+        # No external launcher: spawn one process per device.
+        world_size = device_module.device_count()
+        # Launched as a single process; spawn subprocesses to run the tests.
+        # A rendezvous file is also needed for `init_process_group`.
+        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
+        torch.multiprocessing.spawn(
+            NVSHMEMSymmetricMemoryTest.run_rank,
+            nprocs=world_size,
+            args=(world_size, rdvz_file),
+        )