
Commit aee96bb

wz337 authored and pytorchmergebot committed
[PT-D][Checkpointing] Move distributed checkpointing from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (#88698)
Context in RFC: #86620. The .rst file will be finalized in subsequent PRs.
Pull Request resolved: #88698. Approved by: https://github.com/wanchaol
1 parent 6b521bb commit aee96bb
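
For orientation, here is a minimal save/load sketch against the new package path. It is illustrative only: the toy nn.Linear module, the /tmp/dcp_example directory, and the no_dist=True flag for a single-process run are assumptions made for the example, not part of this PR; a real job would run under an initialized process group and drop no_dist.

import torch.nn as nn
import torch.distributed.checkpoint as dist_cp  # new location after this PR

CHECKPOINT_DIR = "/tmp/dcp_example"  # hypothetical path for the example

model = nn.Linear(4, 4)  # any module's state_dict works the same way
state_dict = model.state_dict()

# Write the state_dict out as a directory of checkpoint files.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    no_dist=True,  # single-process example; omit under a real process group
)

# Read it back; load_state_dict restores tensors into the passed-in state_dict.
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    no_dist=True,
)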

20 files changed: +389 additions, -159 deletions
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+Distributed Checkpoint
+========================
+
+.. automodule:: torch.distributed.checkpoint

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ Features described in this documentation are classified by release status:
    torch.distributed.elastic <distributed.elastic>
    torch.distributed.fsdp <fsdp>
    torch.distributed.optim <distributed.optim>
+   torch.distributed.checkpoint <distributed.checkpoint>
    torch.distributions <distributions>
    torch.fft <fft>
    futures

test/distributed/_shard/checkpoint/test_checkpoint.py renamed to test/distributed/checkpoint/test_checkpoint.py

Lines changed: 3 additions & 3 deletions
@@ -20,17 +20,17 @@
 
 from torch.distributed._shard import sharded_tensor
 
-from torch.distributed._shard.checkpoint.default_planner import (
+from torch.distributed.checkpoint.default_planner import (
     _create_default_local_metadata,
 )
 
-from torch.distributed._shard.checkpoint.metadata import (
+from torch.distributed.checkpoint.metadata import (
     BytesStorageMetadata,
     Metadata,
     TensorStorageMetadata,
 )
 
-from torch.distributed._shard.checkpoint.planner import (
+from torch.distributed.checkpoint.planner import (
     SavePlan,
     SavePlanner,
     LoadPlan,

test/distributed/_shard/checkpoint/test_file_system_checkpoint.py renamed to test/distributed/checkpoint/test_file_system_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
     run_tests,
 )
 
-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
     load_state_dict,

test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py renamed to test/distributed/checkpoint/test_file_system_checkpoint_cpu.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
     run_tests,
 )
 
-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
     load_state_dict,

test/distributed/_shard/checkpoint/test_planner.py renamed to test/distributed/checkpoint/test_planner.py

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@
 import sys
 
 import torch
-from torch.distributed._shard.checkpoint.planner import LoadItemType, WriteItemType
+from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
 
 from torch.distributed._shard.sharded_tensor import (
     Shard,
@@ -18,13 +18,13 @@
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.distributed._shard.checkpoint.metadata import BytesStorageMetadata, MetadataIndex, TensorStorageMetadata
+from torch.distributed.checkpoint.metadata import BytesStorageMetadata, MetadataIndex, TensorStorageMetadata
 from torch.testing._internal.distributed.distributed_utils import (
     with_fake_comms,
     with_dist
 )
 
-from torch.distributed._shard.checkpoint.default_planner import (
+from torch.distributed.checkpoint.default_planner import (
     create_default_global_save_plan,
     create_default_local_save_plan,
     create_default_local_load_plan,

test/distributed/_shard/checkpoint/test_utils.py renamed to test/distributed/checkpoint/test_utils.py

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.distributed._shard.checkpoint.utils import find_state_dict_object
-from torch.distributed._shard.checkpoint.metadata import MetadataIndex
+from torch.distributed.checkpoint.utils import find_state_dict_object
+from torch.distributed.checkpoint.metadata import MetadataIndex
 from torch.testing._internal.distributed.distributed_utils import (
     with_fake_comms
 )
Lines changed: 10 additions & 19 deletions
@@ -1,21 +1,12 @@
-from .metadata import (
-    TensorStorageMetadata,
-    BytesStorageMetadata,
-    ChunkStorageMetadata,
-    Metadata,
-)
-from .state_dict_loader import load_state_dict
-from .state_dict_saver import save_state_dict
-from .storage import StorageReader, StorageWriter
-from .filesystem import FileSystemReader, FileSystemWriter
-from .api import CheckpointException
-
+# Keep old package for BC purposes, this file should be removed once
+# everything moves to the `torch.distributed.checkpoint` package.
+import sys
+import torch
+import warnings
 
-from .planner import (
-    SavePlanner,
-    LoadPlanner,
-    SavePlan,
-    LoadPlan,
-    ReadItem,
-    WriteItem,
+from torch.distributed.checkpoint import * # noqa: F403
+warnings.warn(
+    "torch.distributed._shard.checkpoint will be deprecated, use torch.distributed.checkpoint instead",
+    DeprecationWarning
 )
+sys.modules['torch.distributed._shard.checkpoint'] = torch.distributed.checkpoint
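
With this shim in place, imports of the old path keep working while nudging users toward the new one; a quick sanity-check sketch (only meaningful on the first import of the old path in a fresh interpreter):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The first import of the old path executes the shim above and emits the DeprecationWarning.
    import torch.distributed._shard.checkpoint as old_cp

import torch.distributed.checkpoint as new_cp

# After the sys.modules re-binding, both names expose the same objects, so existing
# call sites such as old_cp.FileSystemWriter(...) are unaffected.
assert old_cp.FileSystemWriter is new_cp.FileSystemWriter
print([str(w.message) for w in caught])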
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from .metadata import (
+    TensorStorageMetadata,
+    BytesStorageMetadata,
+    ChunkStorageMetadata,
+    Metadata,
+)
+from .state_dict_loader import load_state_dict
+from .state_dict_saver import save_state_dict
+from .storage import StorageReader, StorageWriter
+from .filesystem import FileSystemReader, FileSystemWriter
+from .api import CheckpointException
+
+
+from .planner import (
+    SavePlanner,
+    LoadPlanner,
+    SavePlan,
+    LoadPlan,
+    ReadItem,
+    WriteItem,
+)

torch/distributed/_shard/checkpoint/api.py renamed to torch/distributed/checkpoint/api.py

Lines changed: 9 additions & 1 deletion
@@ -3,20 +3,28 @@
 
 WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
 
+__all__ = ["CheckpointException"]
+
+
 def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
     return (exc, tb.extract_tb(exc.__traceback__))
 
+
 def _is_wrapped_exception(obj: Any) -> bool:
     if not isinstance(obj, tuple):
         return False
     if len(obj) != 2:
         return False
-    return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
+    return isinstance(obj[0], BaseException) and isinstance(
+        obj[1], tb.StackSummary
+    )
+
 
 class CheckpointException(BaseException):
     """
     Exception raised if failure was detected as part of a checkpoint load or save.
     """
+
     def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
         super().__init__(msg, failures)
         self._failures = failures
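
CheckpointException derives from BaseException, so a plain `except Exception` will not catch it; callers are expected to guard load/save calls explicitly. A hedged sketch, using a deliberately missing directory as the illustrative failure (whether this particular error surfaces as a CheckpointException depends on where the read fails):

import torch.distributed.checkpoint as dist_cp

state_dict = {}  # illustrative; normally a module's state_dict()

try:
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader("/path/that/does/not/exist"),  # hypothetical path
        no_dist=True,
    )
except dist_cp.CheckpointException as e:
    # str(e) summarizes the per-rank failures collected in the wrapped-exception mapping.
    print(f"checkpoint load failed: {e}")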

torch/distributed/_shard/checkpoint/default_planner.py renamed to torch/distributed/checkpoint/default_planner.py

Lines changed: 61 additions & 21 deletions
@@ -24,37 +24,53 @@
     MetadataIndex,
     Metadata,
     STATE_DICT_TYPE,
-    STORAGE_TYPES
+    STORAGE_TYPES,
 )
 
 from .planner_helpers import (
     _create_read_items,
     _create_write_items,
-    _create_default_metadata_only_plan
+    _create_default_metadata_only_plan,
 )
 
-from .utils import (
-    find_state_dict_object
-)
+from .utils import find_state_dict_object
+
+__all__ = [
+    "DefaultSavePlanner",
+    "DefaultLoadPlanner",
+    "create_default_local_load_plan",
+    "create_default_global_load_plan",
+    "create_default_local_save_plan",
+    "create_default_global_save_plan",
+]
+
 
 class DefaultSavePlanner(SavePlanner):
     def init(self, state_dict: Dict[str, Any], is_coordinator: bool) -> None:
         self.state_dict = state_dict
         self.is_coordinator = is_coordinator
 
     def create_local_plan(self) -> SavePlan:
-        self.plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
+        self.plan = create_default_local_save_plan(
+            self.state_dict, self.is_coordinator
+        )
         return self.plan
 
-    def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
-        self.global_plan, self.metadata = create_default_global_save_plan(all_plans)
+    def create_global_plan(
+        self, all_plans: List[SavePlan]
+    ) -> Tuple[List[SavePlan], Metadata]:
+        self.global_plan, self.metadata = create_default_global_save_plan(
+            all_plans
+        )
         return self.global_plan, self.metadata
 
     def finish_plan(self, new_plan: SavePlan) -> SavePlan:
         self.plan = new_plan
         return new_plan
 
-    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
+    def resolve_data(
+        self, write_item: WriteItem
+    ) -> Union[torch.Tensor, io.BytesIO]:
         object = self.lookup_object(write_item.index)
         return self.transform_object(write_item, object)
 
@@ -76,7 +92,12 @@ def transform_object(self, write_item: WriteItem, object: Any):
 
 
 class DefaultLoadPlanner(LoadPlanner):
-    def init(self, state_dict: STATE_DICT_TYPE, metadata: Metadata, is_coordinator: bool) -> None:
+    def init(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
         self.state_dict = state_dict
         self.metadata = metadata
         self.is_coordinator = is_coordinator
@@ -110,7 +131,9 @@ def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
         """
         This is an extension from the planner interface to make it easy to extend the default planner
         """
-        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
+        return narrow_tensor_by_index(
+            tensor, read_item.dest_offsets, read_item.lengths
+        )
 
 
 def create_default_local_load_plan(
@@ -133,7 +156,10 @@ def create_default_local_load_plan(
 
     return LoadPlan(requests)
 
-def create_default_global_load_plan(all_plans: List[LoadPlan]) -> List[LoadPlan]:
+
+def create_default_global_load_plan(
+    all_plans: List[LoadPlan],
+) -> List[LoadPlan]:
     """
     Create global load plan used by DefaultLoadPlanner.
 
@@ -142,7 +168,10 @@ def create_default_global_load_plan(all_plans: List[LoadPlan]) -> List[LoadPlan]
     """
     return all_plans
 
-def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan:
+
+def create_default_local_save_plan(
+    state_dict: Dict[str, Any], is_coordinator: bool
+) -> SavePlan:
     """
     Create the ``SavePlan`` used by DefaultSavePlanner.
 
@@ -157,7 +186,10 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
         requests += _create_write_items(fqn, obj)
     return SavePlan(requests)
 
-def create_default_global_save_plan(all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
+
+def create_default_global_save_plan(
+    all_plans: List[SavePlan],
+) -> Tuple[List[SavePlan], Metadata]:
     """
     Create the global plan and metadata used by DefaultSavePlanner.
 
@@ -180,21 +212,29 @@ def create_default_global_save_plan(all_plans: List[SavePlan]) -> Tuple[List[Sav
             assert item.tensor_data is not None
             tensor_md = cast(
                 TensorStorageMetadata,
-                md.setdefault(item.index.fqn, TensorStorageMetadata(
-                    properties=item.tensor_data.properties,
-                    size=item.tensor_data.size,
-                    chunks=[],
-                ))
+                md.setdefault(
+                    item.index.fqn,
+                    TensorStorageMetadata(
+                        properties=item.tensor_data.properties,
+                        size=item.tensor_data.size,
+                        chunks=[],
+                    ),
+                ),
+            )
+            new_index = dataclasses.replace(
+                item.index, index=len(tensor_md.chunks)
            )
-            new_index = dataclasses.replace(item.index, index=len(tensor_md.chunks))
             new_item = dataclasses.replace(item, index=new_index)
             new_items.append(new_item)
 
-            assert item.tensor_data.chunk is not None, f"Cannot create MD for tensor without bounds. FQN: {item.index.fqn}"
+            assert (
+                item.tensor_data.chunk is not None
+            ), f"Cannot create MD for tensor without bounds. FQN: {item.index.fqn}"
             tensor_md.chunks.append(item.tensor_data.chunk)
         new_plans.append(dataclasses.replace(plan, items=new_items))
     return (new_plans, Metadata(md))
 
+
 def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
     """
     Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``.
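
The default planner is meant to be extended (see the transform_tensor docstring above); a minimal subclass sketch, where the class name and the print are illustrative rather than part of this PR:

from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.planner import SavePlan


class LoggingSavePlanner(DefaultSavePlanner):
    """Reuse the default planning logic, just add some visibility."""

    def create_local_plan(self) -> SavePlan:
        plan = super().create_local_plan()
        # Each WriteItem in plan.items describes one tensor or bytes object this rank will write.
        print(f"rank-local save plan has {len(plan.items)} write items")
        return plan

Such a planner would be passed via the planner= argument of save_state_dict.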
