8000 [pytree] support collections.deque type for Python pytree by XuehaiPan · Pull Request #113256 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[pytree] support collections.deque type for Python pytree #113256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
de231ce
[pytree] support collections.deque type for Python pytree
XuehaiPan Nov 8, 2023
99a4fd0
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 8, 2023
e4efbc0
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 8, 2023
0f128a6
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 8, 2023
5b01759
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 8, 2023
272c142
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 8, 2023
3ccc1c9
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 8, 2023
0b44a96
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 22, 2023
90e1588
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 28, 2023
d62eb82
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 30, 2023
9057929
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 30, 2023
9f3c06b
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 30, 2023
b45ef95
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 30, 2023
be28b1d
Update on "[pytree] support collections.deque type for Python pytree"
XuehaiPan Nov 30, 2023
7104738
Update on "[pytre 8000 e] support collections.deque type for Python pytree"
XuehaiPan Nov 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 66 additions & 24 deletions test/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import re
import unittest
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict

import torch
import torch.utils._cxx_pytree as cxx_pytree
Expand Down Expand Up @@ -203,13 +203,13 @@ def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
def run_test(tup):
expected_spec = gen_expected_fn(tup)
values, treespec = pytree_impl.tree_flatten(tup)
self.assertTrue(isinstance(values, list))
self.assertIsInstance(values, list)
self.assertEqual(values, list(tup))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, tuple))
self.assertIsInstance(unflattened, tuple)

run_test(())
run_test((1.0,))
Expand Down Expand Up @@ -238,13 +238,13 @@ def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
def run_test(lst):
expected_spec = gen_expected_fn(lst)
values, treespec = pytree_impl.tree_flatten(lst)
self.assertTrue(isinstance(values, list))
self.assertIsInstance(values, list)
self.assertEqual(values, lst)
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, lst)
self.assertTrue(isinstance(unflattened, list))
self.assertIsInstance(unflattened, list)

run_test([])
run_test([1.0, 2])
Expand Down Expand Up @@ -277,13 +277,13 @@ def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
def run_test(dct):
expected_spec = gen_expected_fn(dct)
values, treespec = pytree_impl.tree_flatten(dct)
self.assertTrue(isinstance(values, list))
self.assertIsInstance(values, list)
self.assertEqual(values, list(dct.values()))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, dct)
self.assertTrue(isinstance(unflattened, dict))
self.assertIsInstance(unflattened, dict)

run_test({})
run_test({"a": 1})
Expand Down Expand Up @@ -320,13 +320,13 @@ def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn):
def run_test(odict):
expected_spec = gen_expected_fn(odict)
values, treespec = pytree_impl.tree_flatten(odict)
self.assertTrue(isinstance(values, list))
self.assertIsInstance(values, list)
self.assertEqual(values, list(odict.values()))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, odict)
self.assertTrue(isinstance(unflattened, OrderedDict))
self.assertIsInstance(unflattened, OrderedDict)

od = OrderedDict()
run_test(od)
Expand Down Expand Up @@ -364,21 +364,61 @@ def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn):
def run_test(ddct):
expected_spec = gen_expected_fn(ddct)
values, treespec = pytree_impl.tree_flatten(ddct)
self.assertTrue(isinstance(values, list))
self.assertIsInstance(values, list)
self.assertEqual(values, list(ddct.values()))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, ddct)
self.assertEqual(unflattened.default_factory, ddct.default_factory)
self.assertTrue(isinstance(unflattened, defaultdict))
self.assertIsInstance(unflattened, defaultdict)

run_test(defaultdict(list, {}))
run_test(defaultdict(int, {"a": 1}))
run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)}))
run_test(defaultdict(int, {1: torch.randn(2, 3)}))
run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}))

@parametrize(
"pytree_impl,gen_expected_fn",
[
subtest(
(
py_pytree,
lambda deq: py_pytree.TreeSpec(
deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
),
),
name="py",
),
subtest(
(
cxx_pytree,
lambda deq: cxx_pytree.tree_structure(
deque(deq, maxlen=deq.maxlen)
),
),
name="cxx",
),
],
)
def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn):
def run_test(deq):
expected_spec = gen_expected_fn(deq)
values, treespec = pytree_impl.tree_flatten(deq)
self.assertIsInstance(values, list)
self.assertEqual(values, list(deq))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, deq)
self.assertEqual(unflattened.maxlen, deq.maxlen)
self.assertIsInstance(unflattened, deque)

run_test(deque([]))
run_test(deque([1.0, 2]))
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))

@parametrize(
"pytree_impl",
[
Expand All @@ -397,13 +437,13 @@ def run_test(tup):
else:
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
values, treespec = pytree_impl.tree_flatten(tup)
self.assertTrue(isinstance(values, list))
self.assertIsInstance(values, list)
self.assertEqual(values, list(tup))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, Point))
self.assertIsInstance(unflattened, Point)

run_test(Point(1.0, 2))
run_test(Point(torch.tensor(1.0), 2))
Expand All @@ -429,7 +469,7 @@ def test_flatten_unflatten_return_type(self, pytree_impl, op):
values, spec = pytree_impl.tree_flatten(expected)
# Check that values is actually List[Tensor] and not (ReturnType(...),)
for value in values:
self.assertTrue(isinstance(value, torch.Tensor))
self.assertIsInstance(value, torch.Tensor)
result = pytree_impl.tree_unflatten(values, spec)

self.assertEqual(type(result), type(expected))
Expand All @@ -445,7 +485,7 @@ def test_flatten_unflatten_return_type(self, pytree_impl, op):
def test_flatten_unflatten_nested(self, pytree_impl):
def run_test(pytree):
values, treespec = pytree_impl.tree_flatten(pytree)
self.assertTrue(isinstance(values, list))
self.assertIsInstance(values, list)
self.assertEqual(len(values), treespec.num_leaves)

# NB: python basic data structures (dict list tuple) all have
Expand Down Expand Up @@ -607,15 +647,17 @@ def __init__(self, x, y):
)

def test_treespec_equality(self):
self.assertTrue(
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
self.assertEqual(
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
)
self.assertTrue(
py_pytree.TreeSpec(list, None, []) == py_pytree.TreeSpec(list, None, []),
self.assertEqual(
py_pytree.TreeSpec(list, None, []),
py_pytree.TreeSpec(list, None, []),
)
self.assertTrue(
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()])
== py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
self.assertEqual(
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
)
self.assertFalse(
py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
Expand Down Expand Up @@ -830,7 +872,7 @@ def __init__(self, x, y):
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
serialized_spec = py_pytree.treespec_dumps(spec, 1)
self.assertTrue("moo" in serialized_spec)
self.assertIn("moo", serialized_spec)
roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)

Expand Down Expand Up @@ -942,7 +984,7 @@ def test_saved_serialized(self):

class TestCxxPytree(TestCase):
def test_treespec_equality(self):
self.assertTrue(cxx_pytree.LeafSpec() == cxx_pytree.LeafSpec())
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())

@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
def test_treespec_repr(self):
Expand Down
15 changes: 15 additions & 0 deletions torch/utils/_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Callable,
cast,
DefaultDict,
Deque,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -385,6 +386,14 @@ def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
return [default_factory, dict_context]


def _deque_flatten(deq: Deque[Any]) -> Tuple[List[Any], Context]:
return list(deq), deq.maxlen


def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
return deque(values, maxlen=context)


_private_register_pytree_node(
tuple,
_tuple_flatten,
Expand Down Expand Up @@ -425,6 +434,12 @@ def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
to_dumpable_context=_defaultdict_serialize,
from_dumpable_context=_defaultdict_deserialize,
)
_private_register_pytree_node(
deque,
_deque_flatten,
_deque_unflatten,
serialized_type_name="collections.deque",
)


# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
Expand Down
0