8000 [pytree][Easy] preserve `dict` keys in insertion order in CXX pytree · XuehaiPan/pytorch@f677c98 · GitHub
[go: up one dir, main page]

Skip to content

Commit f677c98

Browse files
committed
[pytree][Easy] preserve dict keys in insertion order in CXX pytree
ghstack-source-id: e37db61 Pull Request resolved: pytorch#130140
1 parent 3797143 commit f677c98

File tree

2 files changed

+15
-41
lines changed

2 files changed

+15
-41
lines changed

test/test_pytree.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
run_tests,
2424
skipIfTorchDynamo,
2525
subtest,
26-
TEST_WITH_TORCHDYNAMO,
2726
TestCase,
2827
)
2928

@@ -805,7 +804,6 @@ def test_treespec_equality(self):
805804
py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
806805
)
807806

808-
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
809807
def test_treespec_repr(self):
810808
# Check that it looks sane
811809
pytree = (0, [0, 0, [0]])
@@ -820,20 +818,6 @@ def test_treespec_repr(self):
820818
),
821819
)
822820

823-
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
824-
def test_treespec_repr_dynamo(self):
825-
# Check that it looks sane
826-
pytree = (0, [0, 0, [0]])
827-
_, spec = py_pytree.tree_flatten(pytree)
828-
self.assertExpectedInline(
829-
repr(spec),
830-
"""\
831-
TreeSpec(tuple, None, [*,
832-
TreeSpec(list, None, [*,
833-
*,
834-
TreeSpec(list, None, [*])])])""",
835-
)
836-
837821
@parametrize(
838822
"spec",
839823
[
@@ -1340,21 +1324,12 @@ def setUp(self):
13401324
def test_treespec_equality(self):
13411325
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
13421326

1343-
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
13441327
def test_treespec_repr(self):
13451328
# Check that it looks sane
13461329
pytree = (0, [0, 0, [0]])
13471330
_, spec = cxx_pytree.tree_flatten(pytree)
1348-
self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)")
1349-
1350-
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
1351-
def test_treespec_repr_dynamo(self):
1352-
# Check that it looks sane
1353-
pytree = (0, [0, 0, [0]])
1354-
_, spec = cxx_pytree.tree_flatten(pytree)
1355-
self.assertExpectedInline(
1356-
repr(spec),
1357-
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')",
1331+
self.assertEqual(
1332+
repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')"
13581333
)
13591334

13601335
@parametrize(

torch/utils/_cxx_pytree.py

Lines changed: 13 additions & 14 deletions
< 493F td data-grid-cell-id="diff-90b4584778da473890663bd0cb8d36a25636f0a36b97622369dea979ba7777e9-444-443-2" data-line-anchor="diff-90b4584778da473890663bd0cb8d36a25636f0a36b97622369dea979ba7777e9R443" data-selected="false" role="gridcell" style="background-color:var(--diffBlob-additionLine-bgColor, var(--diffBlob-addition-bgColor-line));padding-right:24px" tabindex="-1" valign="top" class="focusable-grid-cell diff-text-cell right-side-diff-cell left-side">+
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@
7171
]
7272

7373

74+
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
75+
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently
76+
77+
7478
T = TypeVar("T")
7579
S = TypeVar("S")
7680
U = TypeVar("U")
@@ -295,20 +299,15 @@ def tree_flatten(
295299
296300
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
297301
>>> tree_flatten(tree)
298-
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
302+
([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch'))
299303
>>> tree_flatten(1)
300-
([1], PyTreeSpec(*, NoneIsLeaf))
304+
([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
301305
>>> tree_flatten(None)
302-
([None], PyTreeSpec(*, NoneIsLeaf))
303-
304-
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
305-
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
306-
if you want to keep the keys in the insertion order.
307-
306+
([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
308307
>>> from collections import OrderedDict
309308
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
310309
>>> tree_flatten(tree)
311-
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
310+
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch'))
312311
313312
Args:
314313
tree (pytree): A pytree to flatten.
@@ -367,7 +366,7 @@ def tree_iter(
367366
368367
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
369368
>>> list(tree_iter(tree))
370-
[1, 2, 3, 4, None, 5]
369+
[2, 3, 4, 1, None, 5]
371370
>>> list(tree_iter(1))
372371
[1]
373372
>>> list(tree_iter(None))
@@ -402,7 +401,7 @@ def tree_leaves(
402401
403402
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
404403
>>> tree_leaves(tree)
405-
[1, 2, 3, 4, None, 5]
404+
[2, 3, 4, 1, None, 5]
406405
>>> tree_leaves(1)
407406
[1]
408407
>>> tree_leaves(None)
@@ -437,11 +436,11 @@ def tree_structure(
437436
438437
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
439438
>>> tree_structure(tree)
440-
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
439+
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')
441440
>>> tree_structure(1)
442-
PyTreeSpec(*, NoneIsLeaf)
441+
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
443442
>>> tree_structure(None)
444-
PyTreeSpec(*, NoneIsLeaf)
443
445444
446445
Args:
447446
tree (pytree): A pytree to flatten.

0 commit comments

Comments
 (0)
0