8000 [Distributed][CI] Rework continuous TestCase by kwen2501 · Pull Request #153653 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update
[ghstack-poisoned]
  • Loading branch information
kwen2501 committed May 23, 2025
commit 0ecc3d381ffa064190febdc56f5a650bbc758145
29 changes: 25 additions & 4 deletions test/distributed/pipelining/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@


class ExampleCode(torch.nn.Module):
def __init__(self, d_hid):
def __init__(self, d_hid, splits=2):
assert splits <= 4
super().__init__()
self.splits = splits
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)
self.lin2 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x):
x = torch.mm(x, self.mm_param0)
Expand All @@ -24,21 +27,31 @@ def forward(self, x):
pipe_split()
x = torch.relu(x) + a_constant
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
x = torch.relu(x)
if self.splits > 2:
pipe_split()
x = self.lin1(x)
x = torch.relu(x)
if self.splits > 3:
pipe_split()
x = self.lin2(x)
x = torch.relu(x)
return x


class ModelWithKwargs(torch.nn.Module):
DEFAULT_DHID = 512
DEFAULT_BATCH_SIZE = 256

def __init__(self, d_hid: int = DEFAULT_DHID):
def __init__(self, d_hid: int = DEFAULT_DHID, splits=2):
assert splits <= 4
super().__init__()
self.splits = splits
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)
self.lin2 = torch.nn.Linear(d_hid, d_hid)
self.lin3 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
x = torch.mm(x, self.mm_param0)
Expand All @@ -49,6 +62,14 @@ def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
x = torch.relu(x)
if self.splits > 2:
pipe_split()
x = self.lin2(x)
x = torch.relu(x)
if self.splits > 3:
pipe_split()
x = self.lin3(x)
x = torch.relu(x)
return x


Expand Down
89 changes: 36 additions & 53 deletions