@@ -2,8 +2,6 @@
 # Owner(s): ["oncall: distributed"]
 import copy
 import logging
-import os
-import sys
 import tempfile
 
 from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
@@ -37,6 +35,7 @@
     check_leaked_tensors,
     instantiate_parametrized_tests,
     parametrize,
+    run_tests,
     skip_but_pass_in_sandcastle_if,
 )
 
@@ -48,22 +47,18 @@
 
 torch.manual_seed(0)
 
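+# Device type shared by all ranks; each rank binds to the CUDA device matching its rank (see `device` below).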
+device_type = "cuda"
+
 
 class ScheduleTest(MultiProcContinousTest):
     @classmethod
     def backend_str(cls) -> str:
         # Testing with NCCL backend
         return "nccl"
 
-    @classmethod
-    def setUpClass(cls):
-        """
-        Class-scope test fixture. Run once for entire test class, before any test starts.
-        Set up the device.
-        """
-        super().setUpClass()
-        dev_id = cls.rank % torch.cuda.device_count()
-        cls.device = torch.device(f"cuda:{dev_id}")
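+    # Each rank owns the CUDA device whose index equals its rank.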
+    @property
+    def device(self) -> torch.device:
+        return torch.device(device_type, self.rank)
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@@ -77,7 +72,7 @@ def test_forward_only(self, ScheduleClass):
         x = torch.randn(batch_size, d_hid, device=self.device)
         x_clone = x.clone()
 
-        num_microbatches = 4
+        num_microbatches = 2 * self.world_size
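+        # With 2 * world_size microbatches, each of the world_size stages gets two chunks.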
         x_mb = x.chunk(num_microbatches)[0]
 
         # Create a pipeline
@@ -159,6 +154,12 @@ def test_multi_iter(self, ScheduleClass):
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
     def test_kwargs_with_tracer(self, ScheduleClass):
+        # Model has two stages only, thus limiting group size to 2
+        group_size = 2
+        group = dist.new_group(list(range(group_size)))
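+        # Ranks outside the subgroup have no stage to run and exit early.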
+        if self.rank >= group_size:
+            return
+
         mod = ModelWithKwargs(d_hid)
         mod.to(self.device)
 
@@ -180,6 +181,7 @@ def test_kwargs_with_tracer(self, ScheduleClass):
         stage = pipe.build_stage(
             self.rank,
             self.device,
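+            # Restrict this stage's communication to the two-rank subgroup.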
+            group=group,
         )
 
         # Attach to a schedule
@@ -188,16 +190,16 @@ def test_kwargs_with_tracer(self, ScheduleClass):
         # Run
         if self.rank == 0:
             schedule.step(x, y=y)
-        elif self.rank == self.world_size - 1:
+        elif self.rank == group_size - 1:
             losses = []
             out = schedule.step(target=target, losses=losses)
         else:
             schedule.step()
 
-        dist.barrier()
+        # dist.barrier()
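+        # A barrier over the default group would hang here, since ranks outside
+        # the subgroup already returned from the test.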
 
         # Last rank checks result
-        if self.rank == self.world_size - 1:
+        if self.rank == group_size - 1:
             ref_out = mod(x, y=y)
             ref_loss = loss_fn(ref_out, target)
             pipe_loss = sum(losses)
@@ -207,9 +209,8 @@ def test_kwargs_with_tracer(self, ScheduleClass):
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
-    @parametrize("ModelClass", [MultiMLP])
-    def test_grad_with_tracer(self, ScheduleClass, ModelClass):
-        mod = ModelClass(d_hid)
+    def test_grad_with_tracer(self, ScheduleClass):
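+        # One MLP layer per rank, so the model splits into world_size stages.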
+        mod = MultiMLP(d_hid, n_layers=self.world_size)
         mod.to(self.device)
 
         ref_mod = copy.deepcopy(mod)
@@ -229,7 +230,7 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass):
         ref_loss.backward()
 
         # Create a pipeline
-        chunks = 4
+        chunks = 2 * self.world_size
         x_mb = x.chunk(chunks)[0]
         split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
         pipe = pipeline(
@@ -307,7 +308,7 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference):
         # Get a submodule, e.g. `layers.0` or `layers.1`
         submod_name = f"layers.{self.rank}"
         stage_module = full_mod.get_submodule(submod_name)
-        chunks = 4
+        chunks = 2 * self.world_size
 
         if shape_inference:
             input_args = None
@@ -410,7 +411,7 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
         num_microbatches = (
             ScheduleClass.num_microbatches
             if hasattr(ScheduleClass, "num_microbatches")
-            else 8
+            else 2 * self.world_size
         )
         stages = [
             PipelineStage(
@@ -518,13 +519,15 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
             raise
 
     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble])
     def test_schedule_with_native_zero_bubble(self, ScheduleClass):
         print(ScheduleClass)
         if ScheduleClass is ScheduleInterleavedZeroBubble:
             n_stages = 4
-            num_microbatches = 8
+            num_microbatches = 2 * n_stages
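+            # Scale microbatches with the stage count: two chunks per stage.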
             rank_stages = {
                 0: [0, 2],
                 1: [1, 3],
@@ -612,7 +615,9 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass):
             raise
 
     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize(
         "ScheduleClass",
         [
@@ -717,7 +722,9 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
             raise
 
     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize(
         "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble]
     )
@@ -822,7 +829,9 @@ def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime):
             raise
 
     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
     def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
         stages_per_rank = 2
@@ -942,30 +951,4 @@ def dw_runner():
 
 
 if __name__ == "__main__":
-    # Check if GPU and NCCL are available
-    if not (
-        dist.is_available()
-        and dist.is_nccl_available()
-        and torch.cuda.device_count() > 1
-    ):
-        print(
-            "c10d NCCL not available or not enough GPUs, skipping tests",
-            file=sys.stderr,
-        )
-        sys.exit(0)
-
-    rank = int(os.getenv("RANK", -1))
-    world_size = int(os.getenv("WORLD_SIZE", 2))
-
-    if rank != -1:
-        # Launched with torchrun or other multi-proc launchers. Directly run the test.
-        ScheduleTest.run_rank(rank, world_size)
-    else:
-        # Launched as a single process. Spawn subprocess to run the tests.
-        # Also need a rendezvous file for `init_process_group` purpose.
-        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
-        torch.multiprocessing.spawn(
-            ScheduleTest.run_rank,
-            nprocs=world_size,
-            args=(world_size, rdvz_file),
-        )
+    run_tests()
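+    # Process launch and rendezvous are handled by the shared test harness;
+    # GPU/NCCL availability is enforced by the decorators on each test.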