diff --git a/test/distributed/_test_template.py b/test/distributed/_test_template.py new file mode 100644 index 0000000000000..74a38f7136482 --- /dev/null +++ b/test/distributed/_test_template.py @@ -0,0 +1,16 @@ +# Owner(s): ["oncall: distributed"] + +from torch.testing._internal.common_distributed import MultiProcContinousTest +from torch.testing._internal.common_utils import run_tests + + +class TestTemplate(MultiProcContinousTest): + def testABC(self): + print(f"rank {self.rank} of {self.world_size} testing ABC") + + def testDEF(self): + print(f"rank {self.rank} of {self.world_size} testing DEF") + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipelining/model_registry.py b/test/distributed/pipelining/model_registry.py index 05d4e54176f98..4e0b359a70480 100644 --- a/test/distributed/pipelining/model_registry.py +++ b/test/distributed/pipelining/model_registry.py @@ -7,13 +7,16 @@ class ExampleCode(torch.nn.Module): - def __init__(self, d_hid): + def __init__(self, d_hid, splits=2): + assert splits <= 4 super().__init__() + self.splits = splits self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False)) self.lin0 = torch.nn.Linear(d_hid, d_hid) self.lin1 = torch.nn.Linear(d_hid, d_hid) + self.lin2 = torch.nn.Linear(d_hid, d_hid) def forward(self, x): x = torch.mm(x, self.mm_param0) @@ -24,8 +27,14 @@ def forward(self, x): pipe_split() x = torch.relu(x) + a_constant x = torch.mm(x, self.mm_param1) - x = self.lin1(x) - x = torch.relu(x) + if self.splits > 2: + pipe_split() + x = self.lin1(x) + x = torch.relu(x) + if self.splits > 3: + pipe_split() + x = self.lin2(x) + x = torch.relu(x) return x @@ -33,12 +42,16 @@ class ModelWithKwargs(torch.nn.Module): DEFAULT_DHID = 512 DEFAULT_BATCH_SIZE = 256 - def __init__(self, d_hid: int = DEFAULT_DHID): + def __init__(self, d_hid: int = DEFAULT_DHID, splits=2): + assert splits <= 4 super().__init__() + self.splits = splits self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.lin0 = torch.nn.Linear(d_hid, d_hid) self.lin1 = torch.nn.Linear(d_hid, d_hid) + self.lin2 = torch.nn.Linear(d_hid, d_hid) + self.lin3 = torch.nn.Linear(d_hid, d_hid) def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)): x = torch.mm(x, self.mm_param0) @@ -49,6 +62,14 @@ def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)): x = torch.mm(x, self.mm_param1) x = self.lin1(x) x = torch.relu(x) + if self.splits > 2: + pipe_split() + x = self.lin2(x) + x = torch.relu(x) + if self.splits > 3: + pipe_split() + x = self.lin3(x) + x = torch.relu(x) return x diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index 8491881f7fe23..a8faa6ed12660 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -2,8 +2,6 @@ # Owner(s): ["oncall: distributed"] import copy import logging -import os -import sys import tempfile from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw @@ -37,6 +35,7 @@ check_leaked_tensors, instantiate_parametrized_tests, parametrize, + run_tests, skip_but_pass_in_sandcastle_if, ) @@ -48,6 +47,8 @@ torch.manual_seed(0) +device_type = "cuda" + class ScheduleTest(MultiProcContinousTest): @classmethod @@ -55,15 +56,9 @@ def backend_str(cls) -> str: # Testing with NCCL backend return "nccl" - @classmethod - def setUpClass(cls): - """ - Class-scope test fixture. Run once for entire test class, before any test starts. - Set up the device. - """ - super().setUpClass() - dev_id = cls.rank % torch.cuda.device_count() - cls.device = torch.device(f"cuda:{dev_id}") + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -77,7 +72,7 @@ def test_forward_only(self, ScheduleClass): x = torch.randn(batch_size, d_hid, device=self.device) x_clone = x.clone() - num_microbatches = 4 + num_microbatches = 2 * self.world_size x_mb = x.chunk(num_microbatches)[0] # Create a pipeline @@ -159,6 +154,12 @@ def test_multi_iter(self, ScheduleClass): @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_kwargs_with_tracer(self, ScheduleClass): + # Model has two stages only, thus limiting group size to 2 + group_size = 2 + group = dist.new_group(list(range(group_size))) + if self.rank >= group_size: + return + mod = ModelWithKwargs(d_hid) mod.to(self.device) @@ -180,6 +181,7 @@ def test_kwargs_with_tracer(self, ScheduleClass): stage = pipe.build_stage( self.rank, self.device, + group=group, ) # Attach to a schedule @@ -188,16 +190,16 @@ def test_kwargs_with_tracer(self, ScheduleClass): # Run if self.rank == 0: schedule.step(x, y=y) - elif self.rank == self.world_size - 1: + elif self.rank == group_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() - dist.barrier() + # dist.barrier() # Last rank checks result - if self.rank == self.world_size - 1: + if self.rank == group_size - 1: ref_out = mod(x, y=y) ref_loss = loss_fn(ref_out, target) pipe_loss = sum(losses) @@ -207,9 +209,8 @@ def test_kwargs_with_tracer(self, ScheduleClass): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) - @parametrize("ModelClass", [MultiMLP]) - def test_grad_with_tracer(self, ScheduleClass, ModelClass): - mod = ModelClass(d_hid) + def test_grad_with_tracer(self, ScheduleClass): + mod = MultiMLP(d_hid, n_layers=self.world_size) mod.to(self.device) ref_mod = copy.deepcopy(mod) @@ -229,7 +230,7 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass): ref_loss.backward() # Create a pipeline - chunks = 4 + chunks = 2 * self.world_size x_mb = x.chunk(chunks)[0] split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( @@ -307,7 +308,7 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): # Get a submodule, e.g. `layers.0` or `layers.1` submod_name = f"layers.{self.rank}" stage_module = full_mod.get_submodule(submod_name) - chunks = 4 + chunks = 2 * self.world_size if shape_inference: input_args = None @@ -410,7 +411,7 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): num_microbatches = ( ScheduleClass.num_microbatches if hasattr(ScheduleClass, "num_microbatches") - else 8 + else 2 * self.world_size ) stages = [ PipelineStage( @@ -518,13 +519,15 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): raise @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @skip_but_pass_in_sandcastle_if( + not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" + ) @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble]) def test_schedule_with_native_zero_bubble(self, ScheduleClass): print(ScheduleClass) if ScheduleClass is ScheduleInterleavedZeroBubble: n_stages = 4 - num_microbatches = 8 + num_microbatches = 2 * n_stages rank_stages = { 0: [0, 2], 1: [1, 3], @@ -612,7 +615,9 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): raise @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @skip_but_pass_in_sandcastle_if( + not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" + ) @parametrize( "ScheduleClass", [ @@ -717,7 +722,9 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): raise @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @skip_but_pass_in_sandcastle_if( + not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" + ) @parametrize( "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble] ) @@ -822,7 +829,9 @@ def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime): raise @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @skip_but_pass_in_sandcastle_if( + not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" + ) @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 @@ -942,30 +951,4 @@ def dw_runner(): if __name__ == "__main__": - # Check if GPU and NCCL are available - if not ( - dist.is_available() - and dist.is_nccl_available() - and torch.cuda.device_count() > 1 - ): - print( - "c10d NCCL not available or not enough GPUs, skipping tests", - file=sys.stderr, - ) - sys.exit(0) - - rank = int(os.getenv("RANK", -1)) - world_size = int(os.getenv("WORLD_SIZE", 2)) - - if rank != -1: - # Launched with torchrun or other multi-proc launchers. Directly run the test. - ScheduleTest.run_rank(rank, world_size) - else: - # Launched as a single process. Spawn subprocess to run the tests. - # Also need a rendezvous file for `init_process_group` purpose. - rdvz_file = tempfile.NamedTemporaryFile(delete=False).name - torch.multiprocessing.spawn( - ScheduleTest.run_rank, - nprocs=world_size, - args=(world_size, rdvz_file), - ) + run_tests() diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index cb643ecbf72ae..5ef0ec84fc0eb 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -1,8 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] + import os -import sys -import tempfile from model_registry import ExampleCode, ModelWithKwargs, MultiMLP @@ -18,11 +17,14 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, + MultiProcessTestCase, requires_nccl, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + run_tests, + skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, ) from torch.utils._pytree import tree_map_only @@ -32,6 +34,8 @@ batch_size = 256 chunks = 4 +device_type = "cuda" + torch.manual_seed(0) @@ -66,20 +70,18 @@ def backend_str(cls) -> str: return "nccl" @classmethod - def setUpClass(cls): - """ - Class-scope test fixture. Run once for entire test class, before any test starts. - Set up the device. - """ - super().setUpClass() - dev_id = cls.rank % torch.cuda.device_count() - cls.device = torch.device(f"cuda:{dev_id}") + def device_type(cls) -> str: + return device_type + + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ModelClass", [ExampleCode, MultiMLP]) def test_tracer(self, ModelClass): - mod = ModelClass(d_hid) + mod = ModelClass(d_hid, self.world_size) mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) @@ -119,32 +121,11 @@ def _run_step(x): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) - if self.rank == 0: - # intended to run this code on all ranks, but the problem is if rank0 throws, - # it won't perform the send that unblocks rank 1. - - # TODO(whc) can't test this until fixing args/kwargs issue - # with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"): - # _run_step(torch.randn(batch_size + 1, d_hid, device=self.device)) - - with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): - _run_step(x.to(torch.int32)) - - # output of stage's mlp layer will be flattened by this hook, the stage should err - handle = stage.submod.register_forward_hook(get_flatten_hook()) - with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"): - _run_step(x) - handle.remove() - - stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16)) - with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): - _run_step(x) - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ModelClass", [ModelWithKwargs]) def test_tracer_kwargs(self, ModelClass): - mod = ModelClass(d_hid) + mod = ModelClass(d_hid, self.world_size) mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) @@ -221,23 +202,6 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - if self.rank == 0: - with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"): - _run_step(torch.randn(batch_size + 1, d_hid, device=self.device)) - - with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): - _run_step(x.to(torch.int32)) - - # output of stage's mlp layer will be flattened by this hook, the stage should err - handle = stage_mod.register_forward_hook(get_flatten_hook()) - with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"): - _run_step(x) - handle.remove() - - stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16)) - with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): - _run_step(x) - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_custom_dw_with_fb_schedule(self): @@ -298,28 +262,6 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - if self.rank == 0: - with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"): - _run_step(torch.randn(batch_size + 1, d_hid, device=self.device)) - - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - def test_custom_dw_errors(self): - """Tests expected errors are raised""" - full_mod = MultiMLP(d_hid, n_layers=self.world_size) - full_mod.to(self.device) - stage_mod = full_mod.get_submodule(f"layers.{self.rank}") - - stage_with_dw_builder = PipelineStage( - stage_mod, - self.rank, - self.world_size, - self.device, - dw_builder=lambda: None, - ) - with self.assertRaisesRegex(AssertionError, "backward_one_chunk"): - stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0) - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_output_chunks_memory_usage(self): @@ -381,31 +323,105 @@ def _run_step(x): instantiate_parametrized_tests(StageTest) -if __name__ == "__main__": - # Check if GPU and NCCL are available - if not ( - dist.is_available() - and dist.is_nccl_available() - and torch.cuda.device_count() > 1 - ): - print( - "c10d NCCL not available or not enough GPUs, skipping tests", - file=sys.stderr, + +class StageNegativeTest(MultiProcessTestCase): + @property + def world_size(self) -> int: + return torch.get_device_module(device_type).device_count() + + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) + + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + def init_pg(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + store=store, + rank=self.rank, + world_size=self.world_size, + device_id=self.device, ) - sys.exit(0) - - rank = int(os.getenv("RANK", -1)) - world_size = int(os.getenv("WORLD_SIZE", 2)) - - if rank != -1: - # Launched with torchrun or other multi-proc launchers. Directly run the test. - StageTest.run_rank(rank, world_size) - else: - # Launched as a single process. Spawn subprocess to run the tests. - # Also need a rendezvous file for `init_process_group` purpose. - rdvz_file = tempfile.NamedTemporaryFile(delete=False).name - torch.multiprocessing.spawn( - StageTest.run_rank, - nprocs=world_size, - args=(world_size, rdvz_file), + + @requires_nccl() + @skip_but_pass_in_sandcastle("Flaky in CI") + def test_shape_prop_mismatch(self): + """Tests shape prop errors are raised""" + self.init_pg() + + full_mod = MultiMLP(d_hid, n_layers=self.world_size) + full_mod.to(self.device) + stage_mod = full_mod.get_submodule(f"layers.{self.rank}") + + x = torch.randn(batch_size, d_hid, device=self.device) + + stage = PipelineStage( + stage_mod, + self.rank, + self.world_size, + self.device, ) + + # Attach to a schedule + schedule = ScheduleGPipe(stage, chunks) + + # Run + def _run_step(x): + if self.rank == 0: + return schedule.step(x) + else: + return schedule.step() + + _run_step(x) + + if self.rank == 0: + with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"): + _run_step(torch.randn(batch_size + 1, d_hid, device=self.device)) + + with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): + _run_step(x.to(torch.int32)) + + # output of stage's mlp layer will be flattened by this hook, the stage should err + handle = stage_mod.register_forward_hook(get_flatten_hook()) + with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"): + _run_step(x) + handle.remove() + + stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16)) + with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): + _run_step(x) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_custom_dw_errors(self): + """Tests expected errors are raised""" + self.init_pg() + + full_mod = MultiMLP(d_hid, n_layers=self.world_size) + full_mod.to(self.device) + stage_mod = full_mod.get_submodule(f"layers.{self.rank}") + + stage_with_dw_builder = PipelineStage( + stage_mod, + self.rank, + self.world_size, + self.device, + dw_builder=lambda: None, + ) + with self.assertRaisesRegex(AssertionError, "backward_one_chunk"): + stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_c10d_ops_nccl.py b/test/distributed/test_c10d_ops_nccl.py index 540be51bdb392..068fe606dd9d0 100644 --- a/test/distributed/test_c10d_ops_nccl.py +++ b/test/distributed/test_c10d_ops_nccl.py @@ -11,7 +11,6 @@ import math import os import sys -import tempfile import torch import torch.distributed as c10d @@ -30,9 +29,9 @@ requires_nccl, requires_nccl_version, sm_is_or_higher_than, - TEST_SKIPS, ) from torch.testing._internal.common_utils import ( + run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, TEST_WITH_DEV_DBG_ASAN, @@ -1044,24 +1043,4 @@ def allgather_base(output_t, input_t): if __name__ == "__main__": - if not torch.cuda.is_available(): - sys.exit(TEST_SKIPS["no_cuda"].exit_code) - - rank = int(os.getenv("RANK", -1)) - world_size = int(os.getenv("WORLD_SIZE", -1)) - - if world_size == -1: # Not set by external launcher - world_size = torch.cuda.device_count() - - if rank != -1: - # Launched with torchrun or other multi-proc launchers. Directly run the test. - ProcessGroupNCCLOpTest.run_rank(rank, world_size) - else: - # Launched as a single process. Spawn subprocess to run the tests. - # Also need a rendezvous file for `init_process_group` purpose. - rdvz_file = tempfile.NamedTemporaryFile(delete=False).name - torch.multiprocessing.spawn( - ProcessGroupNCCLOpTest.run_rank, - nprocs=world_size, - args=(world_size, rdvz_file), - ) + run_tests() diff --git a/test/distributed/test_composability.py b/test/distributed/test_composability.py index 7b66f81225ed6..7369d938441b3 100644 --- a/test/distributed/test_composability.py +++ b/test/distributed/test_composability.py @@ -1,11 +1,7 @@ # Owner(s): ["oncall: distributed"] import copy -import os -import sys -import tempfile import torch -import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.distributed.device_mesh import init_device_mesh @@ -30,11 +26,15 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + run_tests, skip_but_pass_in_sandcastle_if, TEST_WITH_ROCM, ) +device_type = "cuda" + + # MLP Layer class MLPModule(torch.nn.Module): def __init__(self, d_hid: int): @@ -92,29 +92,14 @@ def loss_fn(y, target, scale=1e-4): class ComposabilityTest(MultiProcContinousTest): - world_size = 4 - @classmethod def backend_str(cls) -> str: # Testing with NCCL backend return "nccl" - @classmethod - def setUpClass(cls): - """ - Class-scope test fixture. Run once for entire test class, before any test starts. - Set up the device. - """ - super().setUpClass() - dev_id = cls.rank % torch.cuda.device_count() - cls.device = torch.device(f"cuda:{dev_id}") - torch.cuda.set_device(cls.device) - - def _build_mesh(self, mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")): - device_mesh = init_device_mesh( - "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names - ) - return device_mesh + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) def _rand_microbatches(self, dp_mesh, num_microbatches, dim, dtype=torch.float32): full = [ @@ -216,7 +201,12 @@ def test_pp_ddp(self, ScheduleClass): # https://github.com/pytorch/pytorch/issues/144530 return - device_mesh = self._build_mesh((2, 2), ("dp", "pp")) + torch.get_device_module(device_type).set_device(self.device) + mesh_shape = (self.world_size // 2, 2) + mesh_dim_names = ("dp", "pp") + device_mesh = init_device_mesh( + "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names + ) pp_group = device_mesh["pp"].get_group() dp_mesh = device_mesh["dp"] @@ -292,7 +282,12 @@ def test_pp_fsdp(self, dp_type, ScheduleClass): if TEST_WITH_ROCM: return - device_mesh = self._build_mesh((2, 2), ("dp", "pp")) + torch.get_device_module(device_type).set_device(self.device) + mesh_shape = (self.world_size // 2, 2) + mesh_dim_names = ("dp", "pp") + device_mesh = init_device_mesh( + "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names + ) pp_group = device_mesh["pp"].get_group() dp_mesh = device_mesh["dp"] @@ -376,35 +371,12 @@ def apply_dp(partial_model): name = ".".join(parts) ref_p = ref_parameters[name] self.assertTrue(isinstance(p.grad, DTensor)) - torch.testing.assert_close(p.grad.full_tensor(), ref_p.grad) + torch.testing.assert_close( + p.grad.full_tensor(), ref_p.grad, atol=5e-5, rtol=2e-2 + ) instantiate_parametrized_tests(ComposabilityTest) + if __name__ == "__main__": - # Check if GPU and NCCL are available - if not ( - dist.is_available() - and dist.is_nccl_available() - and torch.cuda.device_count() > 3 - ): - print( - "c10d NCCL not available or not enough GPUs, skipping tests", - file=sys.stderr, - ) - sys.exit(0) - - rank = int(os.getenv("RANK", -1)) - world_size = int(os.getenv("WORLD_SIZE", 4)) - - if rank != -1: - # Launched with torchrun or other multi-proc launchers. Directly run the test. - ComposabilityTest.run_rank(rank, world_size) - else: - # Launched as a single process. Spawn subprocess to run the tests. - # Also need a rendezvous file for `init_process_group` purpose. - rdvz_file = tempfile.NamedTemporaryFile(delete=False).name - torch.multiprocessing.spawn( - ComposabilityTest.run_rank, - nprocs=world_size, - args=(world_size, rdvz_file), - ) + run_tests() diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py index de3a394062d68..e56f55efa32ad 100644 --- a/test/distributed/test_nvshmem.py +++ b/test/distributed/test_nvshmem.py @@ -7,16 +7,13 @@ import os import sys -import tempfile import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem -from torch.testing._internal.common_distributed import ( - MultiProcContinousTest, - TEST_SKIPS, -) +from torch.testing._internal.common_distributed import MultiProcContinousTest from torch.testing._internal.common_utils import ( + run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, ) @@ -47,28 +44,20 @@ def requires_nvshmem(): @requires_nvshmem() class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest): - def setUp(self) -> None: - super().setUp() + def _init_device(self) -> None: # TODO: relieve this (seems to hang if without) device_module.set_device(self.device) # NOTE: required for nvshmem allocation torch.empty(1, device=self.device) - # Required by MultiProcContinousTest - @classmethod - def backend_str(cls) -> str: - return "nccl" - - @property - def world_size(self) -> int: - return device_module.device_count() - @property def device(self) -> torch.device: return torch.device(device_type, self.rank) @skipIfRocm def test_nvshmem_all_to_all(self) -> None: + self._init_device() + group_name = dist.group.WORLD.group_name symm_mem.enable_symm_mem_for_group(group_name) @@ -92,6 +81,8 @@ def test_nvshmem_all_to_all(self) -> None: @skipIfRocm def test_nvshmem_all_to_all_vdev(self) -> None: + self._init_device() + group_name = dist.group.WORLD.group_name symm_mem.enable_symm_mem_for_group(group_name) @@ -139,24 +130,4 @@ def test_nvshmem_all_to_all_vdev(self) -> None: if __name__ == "__main__": - if not device_module.is_available(): - sys.exit(TEST_SKIPS["no_cuda"].exit_code) - - # If launched by torchrun, these values would have been set - rank = int(os.getenv("RANK", "-1")) - world_size = int(os.getenv("WORLD_SIZE", "-1")) - - if rank != -1: - # Launched with torchrun or other multi-proc launchers. Directly run the test. - NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size) - else: - # No external launcher, spawn N processes - world_size = device_module.device_count() - # Launched as a single process. Spawn subprocess to run the tests. - # Also need a rendezvous file for `init_process_group` purpose. - rdvz_file = tempfile.NamedTemporaryFile(delete=False).name - torch.multiprocessing.spawn( - NVSHMEMSymmetricMemoryTest.run_rank, - nprocs=world_size, - args=(world_size, rdvz_file), - ) + run_tests() diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 28c186486a7fa..32c952110a125 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -1,6 +1,5 @@ # mypy: ignore-errors -import abc import faulthandler import itertools import logging @@ -38,7 +37,6 @@ find_free_port, IS_SANDCASTLE, retry_on_connect_failures, - run_tests, skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, TEST_HPU, @@ -681,6 +679,11 @@ def _start_processes(self, proc) -> None: self.processes.append(process) def _spawn_processes(self) -> None: + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + pass + proc = torch.multiprocessing.get_context("spawn").Process self._start_processes(proc) @@ -1502,24 +1505,34 @@ def _run( class MultiProcContinousTest(TestCase): # Class variables: + MAIN_PROCESS_RANK = -1 # number of test processes - world_size: int = 2 + world_size: int = -2 # unset state # rank of the current process - rank: int = -1 # unset state + rank: int = -2 # unset state # Rendezvous file rdvz_file: Optional[str] = None # timeout configured per class timeout: timedelta = timedelta(seconds=120) + # Poison pill for rest of tests if one of them fails + poison_pill: bool = False @classmethod - @abc.abstractmethod - def backend_str(cls) -> str: + def backend_str(cls) -> Optional[str]: """ ProcessGroup backend str. To be customized by sub test classes, e.g. "nccl". - Here we raise error. + Otherwise we return None -- lazily decided by tensor. """ - raise NotImplementedError("Please implement backend_str in your test class") + return None + + # Please override if you intend to test on specific device type + @classmethod + def device_type(cls) -> str: + curr_device = torch.accelerator.current_accelerator() + if curr_device is None: + return "cpu" + return curr_device.type @classmethod def opts(cls, high_priority_stream=False): @@ -1530,6 +1543,101 @@ def opts(cls, high_priority_stream=False): """ return None + @classmethod + def _init_pg(cls, rank, world_size, rdvz_file): + assert rdvz_file is not None + store = c10d.FileStore(rdvz_file, world_size) + + # create nccl processgroup with opts + c10d.init_process_group( + backend=cls.backend_str(), + world_size=world_size, + rank=rank, + store=store, + pg_options=cls.opts(), + timeout=cls.timeout, + ) + cls.pg = c10d.distributed_c10d._get_default_group() + + @classmethod + def _run_test_given_id(cls, test_id: str, **kwargs) -> None: + # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' + test_name = test_id.split(".")[-1] + # Get the test function from the test class + self = cls(test_name) + self.rank = cls.rank + self.world_size = cls.world_size + test_fn = getattr(self, test_name) + # Run the test function + test_fn(**kwargs) + + @classmethod + def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue): + # Sub tests are going to access these values, check first + assert 0 <= rank < world_size + # set class variables for the test class + cls.rank = rank + cls.world_size = world_size + + # Initialize the process group + cls._init_pg(rank, world_size, rdvz_file) + + # End of bootstrap + logger.info("Setup complete") + + # Loop forever, waiting for a test name to run + while True: + test_id = task_queue.get() + logger.debug(f"Got test {test_id}") # noqa: G004 + # None means exit + if test_id is None: + break + + # Run the test + try: + cls._run_test_given_id(test_id) + completion_queue.put(test_id) + except BaseException as ex: + # Send the exception back to the dispatcher + completion_queue.put(ex) + + # Termination + logger.info("Terminating ...") + c10d.destroy_process_group() + + @classmethod + def _spawn_processes(cls, world_size) -> None: + cls.processes = [] + cls.task_queues = [] + cls.completion_queues = [] + # Need a rendezvous file for `init_process_group` purpose. + cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name + + # CUDA multiprocessing requires spawn instead of fork, to make sure + # child processes have their own memory space. + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + # The start method has already been set + pass + + for rank in range(int(world_size)): + task_queue = torch.multiprocessing.Queue() + completion_queue = torch.multiprocessing.Queue() + process = torch.multiprocessing.Process( + target=cls._worker_loop, + name="process " + str(rank), + daemon=True, # so that child processes will exit if parent decides to terminate + args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue), + ) + process.start() + cls.processes.append(process) + cls.task_queues.append(task_queue) + cls.completion_queues.append(completion_queue) + logger.info( + "Started process %s with pid %s", rank, process.pid + ) # noqa: UP031 + @classmethod def setUpClass(cls): """ @@ -1537,30 +1645,18 @@ def setUpClass(cls): Set up the process group. """ super().setUpClass() - if not 0 <= cls.rank < cls.world_size: - raise RuntimeError( - "Rank must be set and in the range of 0 to world_size. " - f"World size: {cls.world_size} Rank: {cls.rank}" - ) - if cls.rdvz_file: - store = c10d.FileStore(cls.rdvz_file, cls.world_size) - else: - # torchrun takes care of rendezvous - store = None - opts = cls.opts() - backend = cls.backend_str() - print(f"Testing {backend=}") - # create nccl processgroup with opts - c10d.init_process_group( - backend=backend, - world_size=cls.world_size, - rank=cls.rank, - store=store, - pg_options=opts, - timeout=cls.timeout, + + # Use device count as world size + device_type = cls.device_type() + cls.world_size = torch.get_device_module(device_type).device_count() + if cls.world_size == 0: + raise unittest.SkipTest(f"No {device_type} devices available") + + logger.info( + f"Testing class {cls.__name__} on {cls.world_size} {device_type}" # noqa: G004 ) - cls.pg = c10d.distributed_c10d._get_default_group() - print(f"Rank {cls.rank} setup complete") + + cls._spawn_processes(cls.world_size) @classmethod def tearDownClass(cls): @@ -1568,37 +1664,91 @@ def tearDownClass(cls): Class-scope test fixture. Run once for entire test class, after all tests finish. Tear down the process group. """ - c10d.destroy_process_group() - super().tearDownClass() + logger.debug(f"Joining {cls.world_size} workers") # noqa: G004 + # Enqueue "None" to all workers to tell them to exit + for task_queue in cls.task_queues: + task_queue.put(None) + + # Wait for all workers to exit + for process in cls.processes: + process.join() + # Clear up the rendezvous file - if cls.rdvz_file: - try: - os.remove(cls.rdvz_file) - except OSError: - pass - print(f"Rank {cls.rank} teardown complete") + try: + os.remove(cls.rdvz_file) + except OSError: + pass - @classmethod - def run_rank( - cls, - rank: int, - world_size: int, - rdvz_file: Optional[str] = None, - ): + logger.info(f"Class {cls.__name__} finished") # noqa: G004 + super().tearDownClass() + + def setUp(self) -> None: + """ + Test fixture. Run before each test. """ - This is an entry point for each rank to run the tests in `MultiProcContinousTest`. - In this entry point, we set the class variables for the test class. - Then we run all tests. + super().setUp() - Note: - - This helper only works for a subclass of `MultiProcContinousTest`. + # I am the dispatcher + self.rank = self.MAIN_PROCESS_RANK - Example: - - See `test_c10d_ops_nccl.py`. - """ - # set class variables for the test class - cls.rank = rank - cls.world_size = world_size - cls.rdvz_file = rdvz_file - # Launch tests via `common_utils` infra - run_tests() + # If this test class hits an exception in one test, skip the rest of tests + if self.__class__.poison_pill: + raise unittest.SkipTest(f"Previous test failed, skipping {self.id()}") + + # Enqueue "current test" to all workers + for i, task_queue in enumerate(self.task_queues): + logger.debug(f"Sending Rank {i}: {self.id()}") # noqa: G004 + task_queue.put(self.id()) + + def _worker_run_main_wait(self, fn): + @wraps(fn) + def wrapper(self): + if self.rank == self.MAIN_PROCESS_RANK: + logger.debug(f"Waiting for workers to finish {self.id()}") # noqa: G004 + # Wait for the workers to finish the test + for i, completion_queue in enumerate(self.completion_queues): + rv = completion_queue.get() + if isinstance(rv, BaseException): + # Hit an exception, re-raise it in the main process. + logger.warning( + f"Detected failure from Rank {i} in: {self.id()}, " # noqa: G004 + f"skipping rest of tests in Test class: {self.__class__.__name__}" # noqa: G004 + ) + # Poison rest of tests (because ProcessGroup may be not + # re-usable now) + self.__class__.poison_pill = True + raise rv + + # Success + assert rv == self.id() + logger.debug( + f"Main proc detected rank {i} finished {self.id()}" # noqa: G004 + ) + else: + # Worker just runs the test + fn() + + return types.MethodType(wrapper, self) + + # The main process spawns N subprocesses that run the test. + # Constructor patches current instance test method to + # assume the role of the main process and join its subprocesses, + # or run the underlying test function. + def __init__( + self, method_name: str = "runTest", methodName: str = "runTest" + ) -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName + super().__init__(method_name) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self._worker_run_main_wait(fn)) + except AttributeError as e: + if methodName != "runTest": + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError( + f"no such test method in {self.__class__}: {methodName}" + ) from e