[FSDP/Checkpoint] Activation offload support in checkpoint_wrapper (#70165) · pytorch/pytorch@a197f3f · GitHub
[go: up one dir, main page]

Skip to content

Commit a197f3f

Browse files
rohan-varma authored and facebook-github-bot committed
[FSDP/Checkpoint] Activation offload support in checkpoint_wrapper (#70165)
Summary: Pull Request resolved: #70165. Implements activation offload support in the checkpoint_wrapper API via save_on_cpu hooks. We avoid modifying the torch.utils.checkpoint implementation and instead compose offload + checkpoint, using the save_on_cpu hook for the former. ghstack-source-id: 146078900. Test Plan: CI. Reviewed By: zhaojuanmao. Differential Revision: D33228820. fbshipit-source-id: 98b4da0828462c41c381689ee07360ad014e808a
1 parent e428a90 commit a197f3f

File tree

2 files changed

+111
-40
lines changed

2 files changed

+111
-40
lines changed

test/distributed/fsdp/test_fsdp_checkpoint.py

Lines changed: 85 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Owner(s): ["oncall: distributed"]
22

3+
import contextlib
34
from copy import deepcopy
45
from functools import partial
56

67
import torch
78
import torch.nn as nn
8-
from torch.utils.checkpoint import checkpoint
99
from torch.distributed._fsdp.fully_sharded_data_parallel import (
1010
FullyShardedDataParallel as FSDP,
1111
CPUOffload,
@@ -25,12 +25,19 @@
2525
parametrize,
2626
instantiate_parametrized_tests,
2727
)
28+
from torch.utils.checkpoint import checkpoint
2829

2930

3031
class TestFSDPCheckpoint(FSDPTest):
31-
3232
class SequentialModule(nn.Module):
33-
def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_kwargs):
33+
def __init__(
34+
self,
35+
checkpoint_layer=False,
36+
offload_activations=False,
37+
wrap_fsdp=False,
38+
*fsdp_args,
39+
**fsdp_kwargs,
40+
):
3441
torch.manual_seed(0)
3542
torch.cuda.manual_seed(0)
3643
super().__init__()
@@ -39,15 +46,16 @@ def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_k
3946
l3 = nn.Linear(3, 3).cuda()
4047

4148
if checkpoint_layer:
42-
l1 = checkpoint_wrapper(l1)
43-
l2 = checkpoint_wrapper(l2)
44-
l3 = checkpoint_wrapper(l3)
49+
ckpt_wrapper = partial(
50+
checkpoint_wrapper, offload_to_cpu=offload_activations
51+
)
52+
53+
l1 = ckpt_wrapper(l1)
54+
l2 = ckpt_wrapper(l2)
55+
l3 = ckpt_wrapper(l3)
4556

4657
fsdp_wrapper = partial(
47-
_maybe_wrap_fsdp,
48-
wrap_fsdp=wrap_fsdp,
49-
*fsdp_args,
50-
**fsdp_kwargs
58+
_maybe_wrap_fsdp, wrap_fsdp=wrap_fsdp, *fsdp_args, **fsdp_kwargs
5159
)
5260
self.ffn = nn.Sequential(
5361
fsdp_wrapper(l1),
@@ -58,7 +66,6 @@ def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_k
5866
def forward(self, x):
5967
return self.ffn(x)
6068

61-
6269
def _verify_parity(self, losses, outputs, models):
6370
assert losses
6471
assert outputs
@@ -79,18 +86,23 @@ def _verify_parity(self, losses, outputs, models):
7986
@skip_if_lt_x_gpu(2)
8087
@parametrize(
8188
"cpu_offload",
82-
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
89+
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
8390
)
84-
def test_checkpoint_fsdp_wrapping(self, cpu_offload):
91+
@parametrize("offload_activations", [True, False])
92+
def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
8593
# Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
8694
ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
8795
TestFSDPCheckpoint.SequentialModule(
8896
wrap_fsdp=True, cpu_offload=cpu_offload
89-
)
97+
),
98+
offload_to_cpu=offload_activations,
9099
)
91100
# Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
92101
inner_ckpt = TestFSDPCheckpoint.SequentialModule(
93-
checkpoint_layer=True, wrap_fsdp=True, cpu_offload=cpu_offload
102+
checkpoint_layer=True,
103+
offload_activations=offload_activations,
104+
wrap_fsdp=True,
105+
cpu_offload=cpu_offload,
94106
)
95107

96108
baseline = TestFSDPCheckpoint.SequentialModule(
@@ -101,17 +113,29 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload):
101113
# flag set.
102114
inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)
103115

104-
models = [
105-
ckpt_sequential_wrapped_fsdp,
106-
inner_ckpt,
107-
baseline
108-
]
116+
models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
109117

110-
for _ in range(2):
118+
offload_to_cpu_event = "Memcpy DtoH"
119+
120+
for i in range(2):
111121
losses = []
112122
outputs = []
113123
for m in models:
114-
out = m(inp)
124+
check_offload = m != baseline and i == 0 and offload_activations
125+
profiler_ctx = (
126+
torch.profiler.profile(use_cuda=True)
127+
if check_offload
128+
else contextlib.suppress()
129+
)
130+
with profiler_ctx as prof:
131+
out = m(inp)
132+
133+
if check_offload:
134+
event_names = [event.name for event in prof.events()]
135+
offload_occured = any(
136+
offload_to_cpu_event in name for name in event_names
137+
)
138+
self.assertTrue(offload_occured)
115139
loss = out.sum()
116140
loss.backward()
117141
losses.append(loss)
@@ -122,16 +146,23 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload):
122146
@skip_if_lt_x_gpu(2)
123147
@parametrize(
124148
"cpu_offload",
125-
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
149+
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
126150
)
127-
def test_basic_checkpoint_end_to_end(self, cpu_offload):
151+
@parametrize("offload_activations", [True, False])
152+
def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
128153
seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
129154
# Runs FSDP with no checkpointing
130155
fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
131156
# Runs checkpoint-wrapped FSDP
132-
checkpointed_fsdp = checkpoint_wrapper(FSDP(deepcopy(seq), cpu_offload=cpu_offload))
157+
checkpointed_fsdp = checkpoint_wrapper(
158+
FSDP(deepcopy(seq), cpu_offload=cpu_offload),
159+
offload_to_cpu=offload_activations,
160+
)
133161
# Runs FSDP-wrapped checkpointed module
134-
fsdp_wrapped_checkpoint = FSDP(checkpoint_wrapper(deepcopy(seq)), cpu_offload=cpu_offload)
162+
fsdp_wrapped_checkpoint = FSDP(
163+
checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
164+
cpu_offload=cpu_offload,
165+
)
135166
# Runs FSDP with manual calls to checkpoint.
136167
fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
137168
# note that reentrant-based checkpointing requires inputs to have grad
@@ -143,17 +174,39 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload):
143174
fsdp_only_seq,
144175
checkpointed_fsdp,
145176
fsdp_wrapped_checkpoint,
146-
fsdp_call_checkpoint
177+
fsdp_call_checkpoint,
147178
]
148179

149-
for _ in range(6):
180+
offload_to_cpu_event = "Memcpy DtoH"
181+
182+
for i in range(6):
150183
losses = []
151184
outputs = []
152185
for m in models:
153-
if m == fsdp_call_checkpoint:
154-
out = checkpoint(m, inp)
155-
else:
156-
out = m(inp)
186+
check_offload = m != fsdp_only_seq and i == 0 and offload_activations
187+
profiler_ctx = (
188+
torch.profiler.profile(use_cuda=True)
189+
if check_offload
190+
else contextlib.suppress()
191+
)
192+
with profiler_ctx as prof:
193+
if m == fsdp_call_checkpoint:
194+
offload_ctx = (
195+
torch.autograd.graph.save_on_cpu(pin_memory=True)
196+
if offload_activations
197+
else contextlib.suppress()
198+
)
199+
with offload_ctx:
200+
out = checkpoint(m, inp)
201+
else:
202+
out = m(inp)
203+
204+
if check_offload:
205+
event_names = [event.name for event in prof.events()]
206+
offload_occured = any(
207+
offload_to_cpu_event in name for name in event_names
208+
)
209+
self.assertTrue(offload_occured)
157210
loss = out.sum()
158211
loss.backward()
159212
losses.append(loss)

torch/distributed/algorithms/_checkpoint/_checkpoint_wrapper.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from enum import Enum, auto
2+
from contextlib import suppress
23

34
import torch
5+
from torch.autograd.graph import save_on_cpu
46
from torch.utils.checkpoint import checkpoint
57

68

@@ -17,22 +19,28 @@ def __init__(
1719
self,
1820
mod: torch.nn.Module,
1921
checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
22+
offload_to_cpu: bool = False,
2023
):
2124
super().__init__()
2225
self.mod = mod
2326
self.checkpoint_impl = checkpoint_impl
27+
self.offload_to_cpu = offload_to_cpu
2428

2529
def forward(self, *args, **kwargs):
26-
return checkpoint(
27-
self.mod,
28-
use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
29-
*args,
30-
**kwargs,
31-
)
30+
offload_mgr = save_on_cpu(pin_memory=True) if self.offload_to_cpu else suppress()
31+
with offload_mgr: # type: ignore[attr-defined]
32+
return checkpoint(
33+
self.mod,
34+
use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
35+
*args,
36+
**kwargs,
37+
)
3238

3339

3440
def checkpoint_wrapper(
35-
module: torch.nn.Module, checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT
41+
module: torch.nn.Module,
42+
checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
43+
offload_to_cpu: bool = False,
3644
) -> torch.nn.Module:
3745
"""
3846
A convenience wrapper for activation checkpointing. If the module is wrapped
@@ -48,6 +56,10 @@ def checkpoint_wrapper(
4856
checkpoint_impl (Optional[CheckpointImpl]):
4957
The checkpointing implementation to use. Currently only
5058
CheckpointImpl.REENTRANT is supported.
59+
offload_to_cpu (Optional[bool]):
60+
Whether to offload outer activations to CPU. Note that this
61+
currently only works with CheckpointImpl.REENTRANT.
62+
5163
Returns:
5264
(nn.Module):
5365
Wrapped module
@@ -58,4 +70,10 @@ def checkpoint_wrapper(
5870
"No support for non-reentrant based checkpoint implementation."
5971
)
6072

61-
return _CheckpointWrapper(module, checkpoint_impl)
73+
if offload_to_cpu and checkpoint_impl != CheckpointImpl.REENTRANT:
74+
raise ValueError(
75+
"No support for CPU offload activations and non-reentrant based "
76+
"checkpoint implementation."
77+
)
78+
79+
return _CheckpointWrapper(module, checkpoint_impl, offload_to_cpu)

0 commit comments

Comments
 (0)
0