8000 [DCP] Fixes the BC issue where the traversal doesn't support versions… · pytorch/pytorch@74dcff3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 74dcff3

Browse files
committed
[DCP] Fixes the BC issue where the traversal doesn't support versions before 2.4
The original DCP doesn't flattening all the containers, which can cause issues, #125335 intends to solve the issue by flattening all the dictionaries. Unfortunately, it breaks the checkpoints that are saved before 2.4. This also shows some issues of the DCP: 1. DCP should record version in the metadata. 2. DCP should have a nice way to load old state_dict. 3. DCP should unflatten all containers (map, list) not just map. This PR only addresses issue 2 to unblock users. Issue 1 and issue 3 need to be addressed in the future. ghstack-source-id: f207aed Pull Request resolved: #134158
1 parent 2588b5e commit 74dcff3

File tree

5 files changed

+130
-8
lines changed

5 files changed

+130
-8
lines changed

test/distributed/checkpoint/test_compatibility.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,31 @@ def test_storage_meta(self) -> None:
7070
self.assertEqual(storage_meta.save_id, writer.save_id)
7171
self.assertEqual(storage_meta.load_id, reader.load_id)
7272

73+
@with_temp_dir
74+
def test_with_v_2_3(self) -> None:
75+
sd = {
76+
"a": torch.zeros(4, 4),
77+
"dict": {
78+
"dict_a": {"dict_a_1": 1, "dict_a_2": 2},
79+
"dict_b": {"dict_b_1": 1, "dict_b_2": 2},
80+
},
81+
"list": [0, 1, 2, 3, 4, 5],
82+
}
83+
load_sd = {
84+
"a": torch.ones(4, 4),
85+
"dict": {
86+
"dict_a": {"dict_a_1": 2, "dict_a_2": 4},
87+
"dict_b": {"dict_b_1": 2, "dict_b_2": 4},
88+
},
89+
"list": [10, 11, 12, 13, 14, 15],
90+
}
91+
92+
dcp._version._act_like_version = "2_3"
93+
dcp.save(sd, checkpoint_id=self.temp_dir)
94+
dcp._version._act_like_version = None
95+
dcp.load(load_sd, checkpoint_id=self.temp_dir)
96+
self.assertEqual(sd, load_sd)
97+
7398

7499
if __name__ == "__main__":
75100
run_tests()

torch/distributed/checkpoint/_nested_dict.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33

44
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
55

6-
from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
6+
from . import _version
7+
from ._traverse import (
8+
OBJ_PATH,
9+
set_element,
10+
STATE_DICT_ITEM,
11+
traverse_state_dict,
12+
traverse_state_dict_v_2_3,
13+
)
714

815

916
"""
@@ -40,7 +47,16 @@ def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
4047
flattened[new_fqn] = value
4148
mappings[new_fqn] = path
4249

43-
traverse_state_dict(state_dict, flat_copy)
50+
# We started to flatten dictionary since v2.4. But in order to not break
51+
# the checkpoints that were saved before v2.4, we need to keep the old
52+
# traversal so that we can reconstruct those checkpoints.
53+
use_v_2_3 = (
54+
_version._derived_version is not None and _version._derived_version == "2_3"
55+
)
56+
if use_v_2_3:
57+
traverse_state_dict_v_2_3(state_dict, flat_copy)
58+
else:
59+
traverse_state_dict(state_dict, flat_copy)
4460
return flattened, mappings
4561

4662

torch/distributed/checkpoint/_traverse.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,11 @@ def traverse_state_dict(
4040
) -> None:
4141
"""
4242
Invoke ``visitor`` for each value recursively in ``state_dict``.
43-
Mapping, list, and tuple will be flattened and other value types are treated
44-
as the terminal values and will invoke ``visitor``.
45-
Mapping is treated as non terminal node and will be flattened.
46-
List and tuple, on the other hand, will not be flattened unless containing other
47-
mapping containers or tensors.
43+
Mapping will be traversed and ``visitor`` will be applied to the leaf elements.
44+
``visitor`` will only be applied to elements in a list or a tuple, if the
45+
container contains tensors or mappings.
4846
"""
4947

50-
# a value is terminal if it has no other containers values inside it
5148
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
5249
values: Collection[STATE_DICT_ITEM]
5350
if isinstance(value, Mapping):
@@ -78,6 +75,49 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
7875
_traverse_obj((str(key),), value)
7976

8077

78+
def traverse_state_dict_v_2_3(
79+
state_dict: STATE_DICT_TYPE,
80+
visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
81+
keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
82+
) -> None:
83+
"""
84+
Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates
85+
to false for all elements.
86+
By default, all collections with at least one ``torch.Tensor`` element are traversed.
87+
Visitor takes a path argument that is a tuple of the keys used to reach it.
88+
"""
89+
90+
# a value is terminal if it has no other containers values inside it
91+
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
92+
values: Collection[STATE_DICT_ITEM]
93+
if isinstance(value, Mapping):
94+
values = value.values()
95+
elif isinstance(value, list):
96+
values = value
97+
else:
98+
return True
99+
100+
for entry in values:
101+
if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
102+
return False
103+
if keep_traversing is not None and keep_traversing(entry):
104+
return False
105+
return True
106+
107+
def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
108+
if _is_terminal(value):
109+
visitor(path, value)
110+
elif isinstance(value, Mapping):
111+
for k, v in value.items():
112+
_traverse_obj(path + (str(k),), v)
113+
elif isinstance(value, list):
114+
for i, v in enumerate(value):
115+
_traverse_obj(path + (i,), v)
116+
117+
for key, value in state_dict.items():
118+
_traverse_obj((str(key),), value)
119+
120+
81121
def set_element(
82122
root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
83123
) -> None:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates
2+
3+
from typing import Optional
4+
5+
6+
_derived_version: Optional[str] = None

torch/distributed/checkpoint/default_planner.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
)
4747
from torch.distributed.checkpoint.utils import find_state_dict_object
4848

49+
from . import _version
50+
4951

5052
logger: logging.Logger = logging.getLogger(__name__)
5153

@@ -195,6 +197,39 @@ def set_up_planner(
195197

196198
def create_local_plan(self) -> LoadPlan:
197199
assert self.metadata is not None
200+
if self.flatten_state_dict:
201+
# To support checkpoints that are saved before v2.4, we have to
202+
# differentiate if the missing keys are due to old checkpoints.
203+
# The contracts are:
204+
# 1. There are 3 cases when we found a missing key.
205+
# 1.1 Actual missing key, but allow_partial_load is False
206+
# 1.2 Actual missing key, but allow_partial load is True
207+
# 1.3 Old checkpoint, but allow_partial_load is False
208+
# 1.4 Old checkpoint, but allow_partial_load is True
209+
# 2. If we found a missing key, we first convert the keys back to
210+
# the key format of v2.3
211+
# 3. If the previous missing keys are in the v2.3 keys, we assume
212+
# this is a old checkpoint.
213+
# 4. Pass the state_dict to `create_default_local_load_plan()`,
214+
# which has the logic to check missing for allow_partial_load.
215+
# So for 1.2 and 1.4 cases, we delegate allow_partial_load check to
216+
# `create_default_local_load_plan()`. The logic here is to determine
217+
# whether the checkpoint belong to 2.3 (or before) or 2.4 (or after).
218+
current_keys = set(self.state_dict.keys())
219+
load_keys = set(self.metadata.state_dict_metadata.keys())
220+
missing_keys = load_keys - current_keys
221+
if missing_keys:
222+
_version._derived_version = "2_3"
223+
old_state_dict, old_mappings = flatten_state_dict(
224+
self.original_state_dict
225+
)
226+
old_keys = set(old_state_dict.keys())
227+
if old_keys & missing_keys:
228+
self.state_dict, self.mappings = old_state_dict, old_mappings
229+
# _derived_version is only used by flatten_state_dict now.
230+
# Set it back to None so that later we can save to a new version.
231+
_version._derived_version = None
232+
198233
return create_default_local_load_plan(
199234
self.state_dict, self.metadata, not self.allow_partial_load
200235
)

0 commit comments

Comments
 (0)
0