8000 [Cherry-pick][DCP][AC] Add test for apply AC with FSDP1 (#126935) (#1… · pytorch/pytorch@cd033a1 · GitHub
[go: up one dir, main page]

Skip to content

Commit cd033a1

Browse files
authored
[Cherry-pick][DCP][AC] Add test for apply AC with FSDP1 (#126935) (#126992)
[DCP][AC] Add test for apply AC with FSDP1 (#126935) Adding test for this cherry pick. #126559 Pull Request resolved: #126935 Approved by: https://github.com/fegin
1 parent 19058a6 commit cd033a1

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

test/distributed/checkpoint/test_state_dict.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def is_cpu(v):
506506

507507
@with_comms
508508
@skip_if_lt_x_gpu(1)
509-
def test_activation_ckpt_fqns(self) -> None:
509+
def test_activation_ckpt_fqns_ddp(self) -> None:
510510
"""Tests that activation checkpointing prefixes are removed from module names"""
511511
model = CompositeParamModel(device=torch.device("cuda"))
512512
original_keys = get_model_state_dict(model).keys()
@@ -517,6 +517,25 @@ def test_activation_ckpt_fqns(self) -> None:
517517

518518
self.assertEqual(original_keys, new_keys)
519519

520+
@with_comms
521+
@skip_if_lt_x_gpu(1)
522+
def test_activation_ckpt_fqns_fsdp1(self) -> None:
523+
self.run_subtests(
524+
{"use_orig_params": [True, False]},
525+
self._test_activation_ckpt_fqns_fsdp1,
526+
)
527+
528+
def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None:
529+
"""Tests that activation checkpointing prefixes are removed from module names"""
530+
model = CompositeParamModel(device=torch.device("cuda"))
531+
original_keys = get_model_state_dict(model).keys()
532+
533+
apply_activation_checkpointing(model)
534+
model = FSDP(model, use_orig_params=use_orig_params)
535+
new_keys = get_model_state_dict(model).keys()
536+
537+
self.assertEqual(original_keys, new_keys)
538+
520539

521540
if __name__ == "__main__":
522541
run_tests()

0 commit comments

Comments
 (0)
0