8000 [pytree][Easy] preserve `dict` keys in insertion order in CXX pytree … · pytorch/pytorch@9abaaad · GitHub
[go: up one dir, main page]

Skip to content

Commit 9abaaad

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[pytree][Easy] preserve dict keys in insertion order in CXX pytree (#130140)
`optree` and JAX pytree traversal the `dict` in sorted key ordering (see [Key Ordering for Dictionaries](https://github.com/metaopt/optree#key-ordering-for-dictionaries)). While in PyTorch Python pytree, we traversal the `dict` in insertion order. See also: - #114392 This aligns the behavior of CXX pytree with Python pytree. Pull Request resolved: #130140 Approved by: https://github.com/zou3519
1 parent 1f8ff94 commit 9abaaad

File tree

2 files changed

+15
-41
lines changed

2 files changed

+15
-41
lines changed

test/test_pytree.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
run_tests,
2424
skipIfTorchDynamo,
2525
subtest,
26-
TEST_WITH_TORCHDYNAMO,
2726
TestCase,
2827
)
2928

@@ -805,7 +804,6 @@ def test_treespec_equality(self):
805804
py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
806805
)
807806

808-
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
809807
def test_treespec_repr(self):
810808
# Check that it looks sane
811809
pytree = (0, [0, 0, [0]])
@@ -820,20 +818,6 @@ def test_treespec_repr(self):
820818
),
821819
)
822820

823-
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
824-
def test_treespec_repr_dynamo(self):
825-
# Check that it looks sane
826-
pytree = (0, [0, 0, [0]])
827-
_, spec = py_pytree.tree_flatten(pytree)
828-
self.assertExpectedInline(
829-
repr(spec),
830-
"""\
831-
TreeSpec(tuple, None, [*,
832-
TreeSpec(list, None, [*,
833-
*,
834-
TreeSpec(list, None, [*])])])""",
835-
)
836-
837821
@parametrize(
838822
"spec",
839823
[
@@ -1365,21 +1349,12 @@ def setUp(self):
13651349
def test_treespec_equality(self):
13661350
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
13671351

1368-
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
13691352
def test_treespec_repr(self):
13701353
# Check that it looks sane
13711354
pytree = (0, [0, 0, [0]])
13721355
_, spec = cxx_pytree.tree_flatten(pytree)
1373-
self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)")
1374-
1375-
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
1376-
def test_treespec_repr_dynamo(self):
1377-
# Check that it looks sane
1378-
pytree = (0, [0, 0, [0]])
1379-
_, spec = cxx_pytree.tree_flatten(pytree)
1380-
self.assertExpectedInline(
1381-
repr(spec),
1382-
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')",
1356+
self.assertEqual(
1357+
repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')"
13831358
)
13841359

13851360
@parametrize(

torch/utils/_cxx_pytree.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@
6161
]
6262

6363

64+
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
65+
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently
66+
67+
6468
T = TypeVar("T")
6569
S = TypeVar("S")
6670
U = TypeVar("U")
@@ -285,20 +289,15 @@ def tree_flatten(
285289
286290
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
287291
>>> tree_flatten(tree)
288-
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
292+
([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch'))
289293
>>> tree_flatten(1)
290-
([1], PyTreeSpec(*, NoneIsLeaf))
294+
([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
291295
>>> tree_flatten(None)
292-
([None], PyTreeSpec(*, NoneIsLeaf))
293-
294-
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
295-
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
296-
if you want to keep the keys in the insertion order.
297-
296+
([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
298297
>>> from collections import OrderedDict
299298
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
300299
>>> tree_flatten(tree)
301-
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
300+
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch'))
302301
303302
Args:
304303
tree (pytree): A pytree to flatten.
@@ -357,7 +356,7 @@ def tree_iter(
357356
358357
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
359358
>>> list(tree_iter(tree))
360-
[1, 2, 3, 4, None, 5]
359+
[2, 3, 4, 1, None, 5]
361360
>>> list(tree_iter(1))
362361
[1]
363362
>>> list(tree_iter(None))
@@ -392,7 +391,7 @@ def tree_leaves(
392391
393392
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
394393
>>> tree_leaves(tree)
395-
[1, 2, 3, 4, None, 5]
394+
[2, 3, 4, 1, None, 5]
396395
>>> tree_leaves(1)
397396
[1]
398397
>>> tree_leaves(None)
@@ -427,11 +426,11 @@ def tree_structure(
427426
428427
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
429428
>>> tree_structure(tree)
430-
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
429+
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')
431430
>>> tree_structure(1)
432-
PyTreeSpec(*, NoneIsLeaf)
431+
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
433432
>>> tree_structure(None)
434-
PyTreeSpec(*, NoneIsLeaf)
433+
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
435434
436435
Args:
437436
tree (pytree): A pytree to flatten.

0 commit comments

Comments
 (0)
0