[FSDP2] Add set_reshard_after_forward (#149103) · pytorch/pytorch@5b8cc47 · GitHub
Commit 5b8cc47

mori360 authored and pytorchmergebot committed
[FSDP2] Add set_reshard_after_forward (#149103)
Fixes #149029. Adds `set_reshard_after_forward` to set `post_forward_mesh_info`, which in turn decides `_reshard_after_forward`. Also adds a unit test similar to `test_fully_shard_communication_count`: after calling `set_reshard_after_forward(True)`, the FSDPModule behaves as if it had `_reshard_after_forward=True`, and likewise when setting it to `False`.

Pull Request resolved: #149103
Approved by: https://github.com/awgu
1 parent a8df5e5 commit 5b8cc47
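
Below is a usage sketch of the new setter (illustrative only, not part of this commit). It assumes a multi-GPU launch via torchrun with NCCL; the toy `nn.Sequential` model, tensor shapes, and process-group setup are placeholder choices:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Toy stand-in for a real model: shard each layer, then the root.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
for layer in model:
    fully_shard(layer)  # reshard_after_forward defaults to True
fully_shard(model)

inp = torch.randn(8, 16, device="cuda")

# Eval: keep parameters unsharded after forward so repeated forwards
# skip the re-all-gather (recurse=True applies the setting to every
# FSDP submodule).
model.set_reshard_after_forward(False)
with torch.no_grad():
    model(inp)

# Back to training: reshard after forward again to lower peak memory.
model.set_reshard_after_forward(True)
model(inp).sum().backward()

dist.destroy_process_group()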

File tree

2 files changed: +87 −0 lines changed

test/distributed/_composable/fsdp/test_fully_shard_comm.py

Lines changed: 58 additions & 0 deletions
@@ -422,6 +422,64 @@ def _test_set_reduce_scatter_divide_factor(self, divide_factor: float):
         self.assertEqual(ref_loss, loss)
         check_sharded_parity(self, ref_model, model)
 
+    @skip_if_lt_x_gpu(2)
+    def test_set_reshard_after_forward(self):
+        """
+        Tests that FSDP issues the expected number of all-gathers and
+        reduce-scatters during a train step when setting reshard_after_forward.
+        The communication counts should match test_fully_shard_communication_count.
+        """
+        self.run_subtests(
+            {"set_reshard_after_forward": [True, False], "recurse": [True, False]},
+            self._test_set_reshard_after_forward_by_communication_count,
+        )
+
+    def _test_set_reshard_after_forward_by_communication_count(
+        self,
+        set_reshard_after_forward: bool,
+        recurse: bool,
+    ):
+        torch.manual_seed(42)
+        model_args = ModelArgs()
+        model = Transformer(model_args)
+        fully_shard_fn = functools.partial(
+            fully_shard, reshard_after_forward=not set_reshard_after_forward
+        )
+        num_blocks = 0
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                fully_shard_fn(module)
+                num_blocks += 1
+        fully_shard_fn(model)
+        num_fsdp_modules = sum(
+            isinstance(module, FSDPModule) for module in model.modules()
+        )
+        model.set_reshard_after_forward(
+            reshard_after_forward=set_reshard_after_forward, recurse=recurse
+        )
+
+        torch.manual_seed(42 + self.rank)
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        with CommDebugMode() as fwd_comm_mode:
+            loss = model(inp)
+        fwd_comm_counts = fwd_comm_mode.get_comm_counts()
+        self.assertEqual(len(fwd_comm_counts), 1)
+        self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_fsdp_modules)
+
+        with CommDebugMode() as bwd_comm_mode:
+            loss.sum().backward()
+        bwd_comm_counts = bwd_comm_mode.get_comm_counts()
+        # If recurse is False, set_reshard_after_forward only affects the root module,
+        # resulting in comm_counts identical to those without set_reshard_after_forward.
+        if recurse == set_reshard_after_forward:
+            self.assertEqual(len(bwd_comm_counts), 2)
+            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_blocks)
+        else:
+            self.assertEqual(len(bwd_comm_counts), 1)
+        self.assertEqual(
+            bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_blocks + 1
+        )
+
 
 class TestFullyShardPrefetch(FSDPTest):
     @property
torch/distributed/fsdp/_fully_shard/_fully_shard.py

Lines changed: 29 additions & 0 deletions
@@ -351,6 +351,35 @@ def set_requires_all_reduce(
                 if fsdp_param_group := state._fsdp_param_group:
                     fsdp_param_group.all_reduce_grads = requires_all_reduce
 
+    def set_reshard_after_forward(
+        self, reshard_after_forward: bool, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should reshard parameters after forward. This can be
+        used to change the ``reshard_after_forward`` FSDP arg at runtime. For
+        example, this can be used to set the FSDP root module's value to
+        ``True`` (since it is otherwise specially set to ``False``), or it can
+        set an FSDP module's value to ``False`` for running evals and set back
+        to ``True`` for training.
+
+        Args:
+            reshard_after_forward (bool): Whether to reshard parameters after
+                forward.
+            recurse (bool): Whether to set for all FSDP submodules or just the
+                passed-in module.
+        """
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDPModule):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.post_forward_mesh_info = (
+                        _get_post_forward_mesh_info(
+                            reshard_after_forward, fsdp_param_group.mesh_info
+                        )
+                    )
+
     def set_reshard_after_backward(
         self, reshard_after_backward: bool, *, recurse: bool = True
     ) -> None: