[Distributed][CI] Rework continuous TestCase (#153653) · pytorch/pytorch@9d922b5 · GitHub

Commit 9d922b5

kwen2501 authored and pytorchmergebot committed
[Distributed][CI] Rework continuous TestCase (#153653)
1. Reworked `MultiProcContinousTest` to spawn processes during `setUpClass` instead of `main`, so that multiple test classes can live in one file.
2. The child processes now run an infinite loop, monitoring test IDs passed from the main process via a task queue. In return, the child processes report completion of each test to the main process via a completion queue.
3. Added a test template.

Pull Request resolved: #153653
Approved by: https://github.com/d4l3k, https://github.com/fegin, https://github.com/fduwjj
1 parent 03e102d commit 9d922b5
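The protocol in point 2 amounts to a pair of queues between the main process and each long-lived worker. Below is a minimal, hypothetical sketch of that dispatch loop using `multiprocessing`; the names (`task_queue`, `completion_queue`, `run_one_test`) and the shutdown sentinel are illustrative assumptions, not the actual `MultiProcContinousTest` internals.

```python
# Hypothetical sketch of the task-queue / completion-queue protocol described
# above. Names and structure are assumptions, not MultiProcContinousTest code.
import multiprocessing as mp


def run_one_test(rank, test_id):
    # Stand-in for executing the named test body on this rank.
    print(f"rank {rank} running {test_id}")


def worker_loop(rank, task_queue, completion_queue):
    # Child process: loop forever, run each test ID sent by the main process,
    # and acknowledge it on the completion queue. A None task means shut down.
    while True:
        test_id = task_queue.get()
        if test_id is None:
            break
        run_one_test(rank, test_id)
        completion_queue.put((rank, test_id))


if __name__ == "__main__":
    world_size = 2
    task_queues = [mp.Queue() for _ in range(world_size)]
    completion_queue = mp.Queue()
    workers = [
        mp.Process(target=worker_loop, args=(r, task_queues[r], completion_queue))
        for r in range(world_size)
    ]
    for w in workers:
        w.start()

    # Main process: dispatch one test ID to every rank, then wait for all acks.
    for q in task_queues:
        q.put("testABC")
    for _ in range(world_size):
        completion_queue.get()

    # Tear down the long-lived workers.
    for q in task_queues:
        q.put(None)
    for w in workers:
        w.join()
```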

File tree

8 files changed: +435 −327 lines


test/distributed/_test_template.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -0,0 +1,16 @@
+# Owner(s): ["oncall: distributed"]
+
+from torch.testing._internal.common_distributed import MultiProcContinousTest
+from torch.testing._internal.common_utils import run_tests
+
+
+class TestTemplate(MultiProcContinousTest):
+    def testABC(self):
+        print(f"rank {self.rank} of {self.world_size} testing ABC")
+
+    def testDEF(self):
+        print(f"rank {self.rank} of {self.world_size} testing DEF")
+
+
+if __name__ == "__main__":
+    run_tests()
```
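Because the worker processes are now spawned in `setUpClass`, a single file can hold more than one such test class (point 1 of the commit message). A hypothetical second class, reusing the imports from the template above, would look like this; the class and test names are illustrative only:

```python
# Hypothetical second class in the same file (not part of this commit).
# Spawning workers per-class in setUpClass is what makes this possible.
class AnotherTestTemplate(MultiProcContinousTest):
    def testXYZ(self):
        print(f"rank {self.rank} of {self.world_size} testing XYZ")
```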

test/distributed/pipelining/model_registry.py

Lines changed: 25 additions & 4 deletions
```diff
@@ -7,13 +7,16 @@


 class ExampleCode(torch.nn.Module):
-    def __init__(self, d_hid):
+    def __init__(self, d_hid, splits=2):
+        assert splits <= 4
         super().__init__()
+        self.splits = splits
         self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
         self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
         self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
         self.lin0 = torch.nn.Linear(d_hid, d_hid)
         self.lin1 = torch.nn.Linear(d_hid, d_hid)
+        self.lin2 = torch.nn.Linear(d_hid, d_hid)

     def forward(self, x):
         x = torch.mm(x, self.mm_param0)
@@ -24,21 +27,31 @@ def forward(self, x):
         pipe_split()
         x = torch.relu(x) + a_constant
         x = torch.mm(x, self.mm_param1)
-        x = self.lin1(x)
-        x = torch.relu(x)
+        if self.splits > 2:
+            pipe_split()
+        x = self.lin1(x)
+        x = torch.relu(x)
+        if self.splits > 3:
+            pipe_split()
+        x = self.lin2(x)
+        x = torch.relu(x)
         return x


 class ModelWithKwargs(torch.nn.Module):
     DEFAULT_DHID = 512
     DEFAULT_BATCH_SIZE = 256

-    def __init__(self, d_hid: int = DEFAULT_DHID):
+    def __init__(self, d_hid: int = DEFAULT_DHID, splits=2):
+        assert splits <= 4
         super().__init__()
+        self.splits = splits
         self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
         self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
         self.lin0 = torch.nn.Linear(d_hid, d_hid)
         self.lin1 = torch.nn.Linear(d_hid, d_hid)
+        self.lin2 = torch.nn.Linear(d_hid, d_hid)
+        self.lin3 = torch.nn.Linear(d_hid, d_hid)

     def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
         x = torch.mm(x, self.mm_param0)
@@ -49,6 +62,14 @@ def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
         x = torch.mm(x, self.mm_param1)
         x = self.lin1(x)
         x = torch.relu(x)
+        if self.splits > 2:
+            pipe_split()
+        x = self.lin2(x)
+        x = torch.relu(x)
+        if self.splits > 3:
+            pipe_split()
+        x = self.lin3(x)
+        x = torch.relu(x)
         return x
```
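As a rough illustration of what the new `splits` argument buys: each extra split inserts one more `pipe_split()` point, so tracing `ExampleCode` with `splits=4` should yield a four-stage pipe. The sketch below follows the `torch.distributed.pipelining.pipeline` tracing API used elsewhere in these tests; the sizes and the `num_stages` inspection are assumptions for illustration, not part of the commit.

```python
# Hedged sketch (not part of the commit): trace ExampleCode with splits=4 and
# inspect the resulting pipe. Assumes model_registry is importable and the
# torch.distributed.pipelining tracer API as used in the tests above.
import torch
from torch.distributed.pipelining import pipeline
from model_registry import ExampleCode

d_hid, batch_size = 16, 8
mod = ExampleCode(d_hid, splits=4)      # three pipe_split() points -> 4 stages
x_mb = torch.randn(batch_size, d_hid)   # one microbatch of example input

pipe = pipeline(mod, mb_args=(x_mb,))
print(pipe.num_stages)                  # expected: 4
```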

test/distributed/pipelining/test_schedule_multiproc.py

Lines changed: 36 additions & 53 deletions
```diff
@@ -2,8 +2,6 @@
 # Owner(s): ["oncall: distributed"]
 import copy
 import logging
-import os
-import sys
 import tempfile

 from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
@@ -37,6 +35,7 @@
     check_leaked_tensors,
     instantiate_parametrized_tests,
     parametrize,
+    run_tests,
     skip_but_pass_in_sandcastle_if,
 )
@@ -48,22 +47,18 @@

 torch.manual_seed(0)

+device_type = "cuda"
+

 class ScheduleTest(MultiProcContinousTest):
     @classmethod
     def backend_str(cls) -> str:
         # Testing with NCCL backend
         return "nccl"

-    @classmethod
-    def setUpClass(cls):
-        """
-        Class-scope test fixture. Run once for entire test class, before any test starts.
-        Set up the device.
-        """
-        super().setUpClass()
-        dev_id = cls.rank % torch.cuda.device_count()
-        cls.device = torch.device(f"cuda:{dev_id}")
+    @property
+    def device(self) -> torch.device:
+        return torch.device(device_type, self.rank)

     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
```
```diff
@@ -77,7 +72,7 @@ def test_forward_only(self, ScheduleClass):
         x = torch.randn(batch_size, d_hid, device=self.device)
         x_clone = x.clone()

-        num_microbatches = 4
+        num_microbatches = 2 * self.world_size
         x_mb = x.chunk(num_microbatches)[0]

         # Create a pipeline
```
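Switching from a fixed `num_microbatches = 4` to `2 * self.world_size` keeps two microbatches per rank however many ranks the suite is launched with. A tiny standalone illustration of the chunking arithmetic (the sizes are assumptions):

```python
# Standalone illustration of the microbatch sizing above (assumed sizes).
import torch

world_size = 4                        # pretend the suite runs on 4 ranks
batch_size, d_hid = 64, 16
x = torch.randn(batch_size, d_hid)

num_microbatches = 2 * world_size     # 8 microbatches, i.e. 2 per rank
x_mb = x.chunk(num_microbatches)[0]   # first microbatch
print(x_mb.shape)                     # torch.Size([8, 16])
```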
```diff
@@ -159,6 +154,12 @@ def test_multi_iter(self, ScheduleClass):
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
     def test_kwargs_with_tracer(self, ScheduleClass):
+        # Model has two stages only, thus limiting group size to 2
+        group_size = 2
+        group = dist.new_group(list(range(group_size)))
+        if self.rank >= group_size:
+            return
+
         mod = ModelWithKwargs(d_hid)
         mod.to(self.device)

@@ -180,6 +181,7 @@ def test_kwargs_with_tracer(self, ScheduleClass):
         stage = pipe.build_stage(
             self.rank,
             self.device,
+            group=group,
         )

         # Attach to a schedule
@@ -188,16 +190,16 @@ def test_kwargs_with_tracer(self, ScheduleClass):
         # Run
         if self.rank == 0:
             schedule.step(x, y=y)
-        elif self.rank == self.world_size - 1:
+        elif self.rank == group_size - 1:
             losses = []
             out = schedule.step(target=target, losses=losses)
         else:
             schedule.step()

-        dist.barrier()
+        # dist.barrier()

         # Last rank checks result
-        if self.rank == self.world_size - 1:
+        if self.rank == group_size - 1:
             ref_out = mod(x, y=y)
             ref_loss = loss_fn(ref_out, target)
             pipe_loss = sum(losses)
```
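`test_kwargs_with_tracer` now pins its two-stage model to a two-rank subgroup, so the test keeps working when the suite runs with more than two ranks. A minimal standalone sketch of that pattern (the gloo backend, port, and spawn-based launch are assumptions for illustration; the real test runs NCCL inside `MultiProcContinousTest`):

```python
# Standalone sketch of the subgroup pattern used above (assumptions: gloo
# backend, single-node spawn; the real test runs under MultiProcContinousTest).
import os
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    group_size = 2
    # new_group() is collective: every rank must call it, even ranks that
    # will not be members of the resulting group.
    group = dist.new_group(list(range(group_size)))
    if rank >= group_size:
        # Ranks outside the subgroup simply sit this test out.
        dist.destroy_process_group()
        return

    # In the real test, `group` is passed to pipe.build_stage(..., group=group).
    print(f"rank {rank} participates in the 2-rank pipeline group: {group}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```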
```diff
@@ -207,9 +209,8 @@ def test_kwargs_with_tracer(self, ScheduleClass):
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
-    @parametrize("ModelClass", [MultiMLP])
-    def test_grad_with_tracer(self, ScheduleClass, ModelClass):
-        mod = ModelClass(d_hid)
+    def test_grad_with_tracer(self, ScheduleClass):
+        mod = MultiMLP(d_hid, n_layers=self.world_size)
         mod.to(self.device)

         ref_mod = copy.deepcopy(mod)
@@ -229,7 +230,7 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass):
         ref_loss.backward()

         # Create a pipeline
-        chunks = 4
+        chunks = 2 * self.world_size
         x_mb = x.chunk(chunks)[0]
         split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
         pipe = pipeline(
@@ -307,7 +308,7 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference):
         # Get a submodule, e.g. `layers.0` or `layers.1`
         submod_name = f"layers.{self.rank}"
         stage_module = full_mod.get_submodule(submod_name)
-        chunks = 4
+        chunks = 2 * self.world_size

         if shape_inference:
             input_args = None
@@ -410,7 +411,7 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
         num_microbatches = (
             ScheduleClass.num_microbatches
             if hasattr(ScheduleClass, "num_microbatches")
-            else 8
+            else 2 * self.world_size
         )
         stages = [
             PipelineStage(
```
```diff
@@ -518,13 +519,15 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
             raise

     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble])
     def test_schedule_with_native_zero_bubble(self, ScheduleClass):
         print(ScheduleClass)
         if ScheduleClass is ScheduleInterleavedZeroBubble:
             n_stages = 4
-            num_microbatches = 8
+            num_microbatches = 2 * n_stages
             rank_stages = {
                 0: [0, 2],
                 1: [1, 3],
@@ -612,7 +615,9 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass):
             raise

     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize(
         "ScheduleClass",
         [
@@ -717,7 +722,9 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
             raise

     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize(
         "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble]
     )
@@ -822,7 +829,9 @@ def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime):
             raise

     @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
+    )
     @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
     def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
         stages_per_rank = 2
```
```diff
@@ -942,30 +951,4 @@ def dw_runner():


 if __name__ == "__main__":
-    # Check if GPU and NCCL are available
-    if not (
-        dist.is_available()
-        and dist.is_nccl_available()
-        and torch.cuda.device_count() > 1
-    ):
-        print(
-            "c10d NCCL not available or not enough GPUs, skipping tests",
-            file=sys.stderr,
-        )
-        sys.exit(0)
-
-    rank = int(os.getenv("RANK", -1))
-    world_size = int(os.getenv("WORLD_SIZE", 2))
-
-    if rank != -1:
-        # Launched with torchrun or other multi-proc launchers. Directly run the test.
-        ScheduleTest.run_rank(rank, world_size)
-    else:
-        # Launched as a single process. Spawn subprocess to run the tests.
-        # Also need a rendezvous file for `init_process_group` purpose.
-        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
-        torch.multiprocessing.spawn(
-            ScheduleTest.run_rank,
-            nprocs=world_size,
-            args=(world_size, rdvz_file),
-        )
+    run_tests()
```

0 commit comments