File tree Expand file tree Collapse file tree 1 file changed +20
-1
lines changed
test/distributed/checkpoint Expand file tree Collapse file tree 1 file changed +20
-1
lines changed Original file line number Diff line number Diff line change @@ -506,7 +506,7 @@ def is_cpu(v):
506
506
507
507
@with_comms
508
508
@skip_if_lt_x_gpu (1 )
509
- def test_activation_ckpt_fqns (self ) -> None :
509
+ def test_activation_ckpt_fqns_ddp (self ) -> None :
510
510
"""Tests that activation checkpointing prefixes are removed from module names"""
511
511
model = CompositeParamModel (device = torch .device ("cuda" ))
512
512
original_keys = get_model_state_dict (model ).keys ()
@@ -517,6 +517,25 @@ def test_activation_ckpt_fqns(self) -> None:
517
517
518
518
self .assertEqual (original_keys , new_keys )
519
519
520
+ @with_comms
521
+ @skip_if_lt_x_gpu (1 )
522
+ def test_activation_ckpt_fqns_fsdp1 (self ) -> None :
523
+ self .run_subtests (
524
+ {"use_orig_params" : [True , False ]},
525
+ self ._test_activation_ckpt_fqns_fsdp1 ,
526
+ )
527
+
528
+ def _test_activation_ckpt_fqns_fsdp1 (self , use_orig_params : bool ) -> None :
529
+ """Tests that activation checkpointing prefixes are removed from module names"""
530
+ model = CompositeParamModel (device = torch .device ("cuda" ))
531
+ original_keys = get_model_state_dict (model ).keys ()
532
+
533
+ apply_activation_checkpointing (model )
534
+ model = FSDP (model , use_orig_params = use_orig_params )
535
+ new_keys = get_model_state_dict (model ).keys ()
536
+
537
+ self .assertEqual (original_keys , new_keys )
538
+
520
539
521
540
if __name__ == "__main__" :
522
541
run_tests ()
You can’t perform that action at this time.
0 commit comments