@@ -1,11 +1,11 @@
 # Owner(s): ["oncall: distributed"]

+import contextlib
 from copy import deepcopy
 from functools import partial

 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 from torch.distributed._fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     CPUOffload,
@@ -25,12 +25,19 @@
     parametrize,
     instantiate_parametrized_tests,
 )
+from torch.utils.checkpoint import checkpoint


 class TestFSDPCheckpoint(FSDPTest):
-
     class SequentialModule(nn.Module):
-        def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_kwargs):
+        def __init__(
+            self,
+            checkpoint_layer=False,
+            offload_activations=False,
+            wrap_fsdp=False,
+            *fsdp_args,
+            **fsdp_kwargs,
+        ):
             torch.manual_seed(0)
             torch.cuda.manual_seed(0)
             super().__init__()
@@ -39,15 +46,16 @@ def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_kwargs):
             l3 = nn.Linear(3, 3).cuda()

             if checkpoint_layer:
-                l1 = checkpoint_wrapper(l1)
-                l2 = checkpoint_wrapper(l2)
-                l3 = checkpoint_wrapper(l3)
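+                # With offload_to_cpu=True, each checkpointed layer also moves
+                # its saved activations to CPU instead of keeping them on GPU.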
+                ckpt_wrapper = partial(
+                    checkpoint_wrapper, offload_to_cpu=offload_activations
+                )
+
+                l1 = ckpt_wrapper(l1)
+                l2 = ckpt_wrapper(l2)
+                l3 = ckpt_wrapper(l3)

             fsdp_wrapper = partial(
-                _maybe_wrap_fsdp,
-                wrap_fsdp=wrap_fsdp,
-                *fsdp_args,
-                **fsdp_kwargs
+                _maybe_wrap_fsdp, wrap_fsdp=wrap_fsdp, *fsdp_args, **fsdp_kwargs
             )
             self.ffn = nn.Sequential(
                 fsdp_wrapper(l1),
@@ -58,7 +66,6 @@ def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_kwargs):
         def forward(self, x):
             return self.ffn(x)

-
     def _verify_parity(self, losses, outputs, models):
         assert losses
         assert outputs
@@ -79,18 +86,23 @@ def _verify_parity(self, losses, outputs, models):
     @skip_if_lt_x_gpu(2)
     @parametrize(
         "cpu_offload",
-        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
+        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
     )
-    def test_checkpoint_fsdp_wrapping(self, cpu_offload):
+    @parametrize("offload_activations", [True, False])
+    def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
         # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
         ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
             TestFSDPCheckpoint.SequentialModule(
                 wrap_fsdp=True, cpu_offload=cpu_offload
-            )
+            ),
+            offload_to_cpu=offload_activations,
         )
         # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
         inner_ckpt = TestFSDPCheckpoint.SequentialModule(
-            checkpoint_layer=True, wrap_fsdp=True, cpu_offload=cpu_offload
+            checkpoint_layer=True,
+            offload_activations=offload_activations,
+            wrap_fsdp=True,
+            cpu_offload=cpu_offload,
         )

         baseline = TestFSDPCheckpoint.SequentialModule(
@@ -101,17 +113,29 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload):
         # flag set.
         inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

-        models = [
-            ckpt_sequential_wrapped_fsdp,
-            inner_ckpt,
-            baseline
-        ]
+        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]

-        for _ in range(2):
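+        # A device-to-host memcpy in the profiler trace is the signal that
+        # activations were offloaded to CPU.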
+        offload_to_cpu_event = "Memcpy DtoH"
+
+        for i in range(2):
             losses = []
             outputs = []
             for m in models:
-                out = m(inp)
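+                # Only the checkpointed models should offload activations, and
+                # observing the first iteration is sufficient.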
+                check_offload = m != baseline and i == 0 and offload_activations
+                profiler_ctx = (
+                    torch.profiler.profile(use_cuda=True)
+                    if check_offload
+                    else contextlib.suppress()
+                )
+                with profiler_ctx as prof:
+                    out = m(inp)
+
+                if check_offload:
+                    event_names = [event.name for event in prof.events()]
+                    offload_occurred = any(
+                        offload_to_cpu_event in name for name in event_names
+                    )
+                    self.assertTrue(offload_occurred)
                 loss = out.sum()
                 loss.backward()
                 losses.append(loss)
@@ -122,16 +146,23 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload):
     @skip_if_lt_x_gpu(2)
     @parametrize(
         "cpu_offload",
-        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
+        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
     )
-    def test_basic_checkpoint_end_to_end(self, cpu_offload):
+    @parametrize("offload_activations", [True, False])
+    def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
         seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
         # Runs FSDP with no checkpointing
         fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
         # Runs checkpoint-wrapped FSDP
-        checkpointed_fsdp = checkpoint_wrapper(FSDP(deepcopy(seq), cpu_offload=cpu_offload))
+        checkpointed_fsdp = checkpoint_wrapper(
+            FSDP(deepcopy(seq), cpu_offload=cpu_offload),
+            offload_to_cpu=offload_activations,
+        )
         # Runs FSDP-wrapped checkpointed module
-        fsdp_wrapped_checkpoint = FSDP(checkpoint_wrapper(deepcopy(seq)), cpu_offload=cpu_offload)
+        fsdp_wrapped_checkpoint = FSDP(
+            checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
+            cpu_offload=cpu_offload,
+        )
         # Runs FSDP with manual calls to checkpoint.
         fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
         # note that reentrant-based checkpointing requires inputs to have grad
@@ -143,17 +174,39 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload):
             fsdp_only_seq,
             checkpointed_fsdp,
             fsdp_wrapped_checkpoint,
-            fsdp_call_checkpoint
+            fsdp_call_checkpoint,
         ]

-        for _ in range(6):
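+        # Same signal as in the test above: a D2H memcpy indicates that
+        # activations were copied out to host memory.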
+        offload_to_cpu_event = "Memcpy DtoH"
+
+        for i in range(6):
             losses = []
             outputs = []
             for m in models:
-                if m == fsdp_call_checkpoint:
-                    out = checkpoint(m, inp)
-                else:
-                    out = m(inp)
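+                # Every model except the non-checkpointed fsdp_only_seq should
+                # offload; checking the first iteration is enough.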
+                check_offload = m != fsdp_only_seq and i == 0 and offload_activations
+                profiler_ctx = (
+                    torch.profiler.profile(use_cuda=True)
+                    if check_offload
+                    else contextlib.suppress()
+                )
+                with profiler_ctx as prof:
+                    if m == fsdp_call_checkpoint:
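+                        # save_on_cpu stores tensors saved for backward in pinned
+                        # host memory, mirroring offload_to_cpu for the manual path.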
+                        offload_ctx = (
+                            torch.autograd.graph.save_on_cpu(pin_memory=True)
+                            if offload_activations
+                            else contextlib.suppress()
+                        )
+                        with offload_ctx:
+                            out = checkpoint(m, inp)
+                    else:
+                        out = m(inp)
+
+                if check_offload:
+                    event_names = [event.name for event in prof.events()]
+                    offload_occurred = any(
+                        offload_to_cpu_event in name for name in event_names
+                    )
+                    self.assertTrue(offload_occurred)
                 loss = out.sum()
                 loss.backward()
                 losses.append(loss)