8000 [DCP] Always flatten mapping even if no tensors present (#125335) · pytorch/pytorch@6f1e3a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6f1e3a6

Browse files
feginpytorchmergebot
authored andcommitted
[DCP] Always flatten mapping even if no tensors present (#125335)
Summary: Right now DCP only flatten a mapping (e.g., di 8000 ct) if that mapping has tensor objects. This behavior is odd as users may save different non-tensor objects on different ranks. Without flattening the mappings, we may lose these non-tensor objects. One use case is dataloader state_dict. We may also want to do so for a list/tuple. But this will cause extra pickles. So we don't do this for now. Pull Request resolved: #125335 Approved by: https://github.com/LucasLLC, https://github.com/wz337 ghstack dependencies: #125333, #125501, #125334
1 parent 790f43c commit 6f1e3a6

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

test/distributed/checkpoint/test_nested_dict.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_flattening_round_trip(self) -> None:
1313
state_dict = {
1414
"key0": 1,
1515
"key1": [1, 2],
16-
"key2": {1: 2, 2: 3},
16+
"key2": {"1": 2, "2": 3},
1717
"key3": torch.tensor([1]),
1818
"key4": [[torch.tensor(2), "x"], [1, 2, 3], {"key6": [44]}],
1919
}
@@ -24,7 +24,7 @@ def test_flattening_round_trip(self) -> None:
2424
{
2525
'key0': 1,
2626
'key1': [1, 2],
27-
'key2': {1: 2, 2: 3},
27+
'key2': {'1': 2, '2': 3},
2828
'key3': tensor([1]),
2929
'key4.0.0': tensor(2),
3030
'key4.0.1': 'x',
@@ -55,7 +55,9 @@ def test_mapping(self) -> None:
5555
self.assertEqual(("k2", 0), mapping["k2.0"])
5656
self.assertEqual(("k2", 1), mapping["k2.1"])
5757
self.assertEqual(("k2", 2, 0, "k3"), mapping["k2.2.0.k3"])
58-
self.assertEqual(("k3",), mapping["k3"])
58+
self.assertEqual(("k3", 0), mapping["k3.0"])
59+
self.assertEqual(("k3", 1), mapping["k3.1"])
60+
self.assertEqual(("k3", 2, 0, "k3"), mapping["k3.2.0.k3"])
5961

6062

6163
if __name__ == "__main__":

test/distributed/checkpoint/test_traverse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ def collect_data(path, value):
3333
self.assertIn(("key1",), data)
3434
self.assertEqual(data[("key1",)], [1, 2])
3535

36-
self.assertIn(("key2",), data)
37-
self.assertEqual(data[("key2",)], {1: 2, 2: 3})
36+
self.assertIn(("key2", "1"), data)
37+
self.assertEqual(data[("key2", "1")], 2)
38+
self.assertIn(("key2", "2"), data)
39+
self.assertEqual(data[("key2", "2")], 3)
3840

3941
self.assertIn(("key3",), data)
4042
self.assertEqual(data[("key3",)], torch.tensor([1]))

torch/distributed/checkpoint/_traverse.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ def traverse_state_dict(
3939
) -> None:
4040
"""
4141
Invoke ``visitor`` for each value recursively in ``state_dict``.
42-
43-
Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates
44-
to false for all elements.
45-
By default, all collections with at least one ``torch.Tensor`` element are traversed.
46-
Visitor takes a path argument that is a tuple of the keys used to reach it.
42+
Mapping, list, and tuple will be flattened and other value types are treated
43+
as the terminal values and will invoke ``visitor``.
44+
Mapping is treated as non terminal node and will be flattened.
45+
List and tuple, on the other hand, will not be flattened unless containing other
46+
mapping containers or tensors.
4747
"""
4848

4949
# a value is terminal if it has no other containers values inside it
5050
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
5151
values: Collection[STATE_DICT_ITEM]
5252
if isinstance(value, Mapping):
53-
values = value.values()
53+
return False
5454
elif isinstance(value, list):
5555
values = value
5656
else:
@@ -64,12 +64,12 @@ def _is_terminal(value: STATE_DICT_ITEM) -> bool:
6464
return True
6565

6666
def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
67-
if _is_terminal(value):
68-
visitor(path, value)
69-
elif isinstance(value, Mapping):
67+
if isinstance(value, Mapping):
7068
for k, v in value.items():
7169
_traverse_obj(path + (str(k),), v)
72-
elif isinstance(value, list):
70+
elif _is_terminal(value):
71+
visitor(path, value)
72+
elif isinstance(value, (list, tuple)):
7373
for i, v in enumerate(value):
7474
_traverse_obj(path + (i,), v)
7575

0 commit comments

Comments
 (0)
0