8000 [pytree] add `treespec_{leaf,tuple,dict}` functions for args_spec mod… · pytorch/pytorch@f0a5430 · GitHub
[go: up one dir, main page]

Skip to content

Commit f0a5430

Browse files
committed
[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification
ghstack-source-id: f785835 Pull Request resolved: #138214
1 parent 23113a2 commit f0a5430

File tree

21 files changed

+314
-145
lines changed

21 files changed

+314
-145
lines changed

test/export/test_export.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@
8282
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
8383
from torch.testing._internal.two_tensor import TwoTensor
8484
from torch.utils._pytree import (
85-
LeafSpec,
8685
register_constant,
8786
tree_flatten,
8887
tree_map,
8988
tree_unflatten,
9089
TreeSpec,
9190
treespec_dumps,
91+
treespec_leaf,
9292
treespec_loads,
9393
)
9494

@@ -5671,7 +5671,7 @@ class MyDataClass:
56715671

56725672
dt = MyDataClass(x=3, y=4)
56735673
flat, spec = tree_flatten(dt)
5674-
self.assertTrue(spec, LeafSpec())
5674+
self.assertTrue(spec, treespec_leaf())
56755675
self.assertTrue(len(flat) == 1)
56765676

56775677
torch.export.register_dataclass(
@@ -5682,7 +5682,9 @@ class MyDataClass:
56825682
flat, spec = tree_flatten(dt)
56835683
self.assertEqual(
56845684
spec,
5685-
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
5685+
TreeSpec(
5686+
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
5687+
),
56865688
)
56875689
self.assertEqual(flat, [3, 4])
56885690

@@ -5715,7 +5717,7 @@ class MyOtherDataClass: # the pytree registration don't allow registering the s
57155717
TreeSpec(
57165718
MyOtherDataClass,
57175719
[["x", "y", "z"], []],
5718-
[LeafSpec(), LeafSpec(), LeafSpec()],
5720+
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
57195721
),
57205722
)
57215723
self.assertEqual(flat, [3, 4, None])

test/test_pytree.py

+54-34
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
6565
A = auto()
6666

6767

68-
python_leafspec = python_pytree.LeafSpec()
69-
70-
7168
class TestGenericPytree(TestCase):
7269
def test_aligned_public_apis(self):
7370
public_apis = python_pytree.__all__
@@ -197,7 +194,7 @@ def test_flatten_unflatten_leaf(self, pytree):
197194
def run_test_with_leaf(leaf):
198195
values, treespec = pytree.tree_flatten(leaf)
199196
self.assertEqual(values, [leaf])
200-
self.assertEqual(treespec, pytree.LeafSpec())
197+
self.assertEqual(treespec, pytree.treespec_leaf())
201198

202199
unflattened = pytree.tree_unflatten(values, treespec)
203200
self.assertEqual(unflattened, leaf)
@@ -215,7 +212,7 @@ def run_test_with_leaf(leaf):
215212
(
216213
python_pytree,
217214
lambda tup: python_pytree.TreeSpec(
218-
tuple, None, [python_leafspec for _ in tup]
215+
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
219216
),
220217
),
221218
name="python",
@@ -254,7 +251,7 @@ def run_test(tup):
254251
(
255252
python_pytree,
256253
lambda lst: python_pytree.TreeSpec(
257-
list, None, [python_leafspec for _ in lst]
254+
list, None, [python_pytree.treespec_leaf() for _ in lst]
258255
),
259256
),
260257
name="python",
@@ -294,7 +291,7 @@ def run_test(lst):
294291
lambda dct: python_pytree.TreeSpec(
295292
dict,
296293
list(dct.keys()),
297-
[python_leafspec for _ in dct.values()],
294+
[python_pytree.treespec_leaf() for _ in dct.values()],
298295
),
299296
),
300297
name="python",
@@ -342,7 +339,7 @@ def run_test(dct):
342339
lambda odict: python_pytree.TreeSpec(
343340
OrderedDict,
344341
list(odict.keys()),
345-
[python_leafspec for _ in odict.values()],
342+
[python_pytree.treespec_leaf() for _ in odict.values()],
346343
),
347344
),
348345
name="python",
@@ -393,7 +390,7 @@ def run_test(odict):
393390
lambda ddct: python_pytree.TreeSpec(
394391
defaultdict,
395392
[ddct.default_factory, list(ddct.keys())],
396-
[python_leafspec for _ in ddct.values()],
393+
[python_pytree.treespec_leaf() for _ in ddct.values()],
397394
),
398395
),
399396
name="python",
@@ -444,7 +441,7 @@ def run_test(ddct):
444441
(
445442
python_pytree,
446443
lambda deq: python_pytree.TreeSpec(
447-
deque, deq.maxlen, [python_leafspec for _ in deq]
444+
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
448445
),
449446
),
450447
name="python",
@@ -491,7 +488,7 @@ def test_flatten_unflatten_namedtuple(self, pytree):
491488
def run_test(tup):
492489
if pytree is python_pytree:
493490
expected_spec = python_pytree.TreeSpec(
494-
namedtuple, Point, [python_leafspec for _ in tup]
491+
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
495492
)
496493
else:
497494
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
@@ -877,16 +874,16 @@ def test_import_pytree_doesnt_import_optree(self):
877874

878875
def test_treespec_equality(self):
879876
self.assertEqual(
880-
python_pytree.LeafSpec(),
881-
python_pytree.LeafSpec(),
877+
python_pytree.treespec_leaf(),
878+
python_pytree.treespec_leaf(),
882879
)
883880
self.assertEqual(
884881
python_pytree.TreeSpec(list, None, []),
885882
python_pytree.TreeSpec(list, None, []),
886883
)
887884
self.assertEqual(
888-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
889-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
885+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
886+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
890887
)
891888
self.assertFalse(
892889
python_pytree.TreeSpec(tuple, None, [])
@@ -921,24 +918,32 @@ def test_treespec_repr(self):
921918
# python_pytree.tree_structure({})
922919
python_pytree.TreeSpec(dict, [], []),
923920
# python_pytree.tree_structure([0])
924-
python_pytree.TreeSpec(list, None, [python_leafspec]),
921+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
925922
# python_pytree.tree_structure([0, 1])
926923
python_pytree.TreeSpec(
927924
list,
928925
None,
929-
[python_leafspec, python_leafspec],
926+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
930927
),
931928
# python_pytree.tree_structure((0, 1, 2))
932929
python_pytree.TreeSpec(
933930
tuple,
934931
None,
935-
[python_leafspec, python_leafspec, python_leafspec],
932+
[
933+
python_pytree.treespec_leaf(),
934+
python_pytree.treespec_leaf(),
935+
python_pytree.treespec_leaf(),
936+
],
936937
),
937938
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
938939
python_pytree.TreeSpec(
939940
dict,
940941
["a", "b", "c"],
941-
[python_leafspec, python_leafspec, python_leafspec],
942+
[
943+
python_pytree.treespec_leaf(),
944+
python_pytree.treespec_leaf(),
945+
python_pytree.treespec_leaf(),
946+
],
942947
),
943948
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
944949
python_pytree.TreeSpec(
@@ -948,13 +953,17 @@ def test_treespec_repr(self):
948953
python_pytree.TreeSpec(
949954
tuple,
950955
None,
951-
[python_leafspec, python_leafspec],
956+
[python_pytree.treespec_leaf(), python_pytree.treespec_l 10000 eaf()],
952957
),
953-
python_leafspec,
958+
python_pytree.treespec_leaf(),
954959
python_pytree.TreeSpec(
955960
dict,
956961
["a", "b", "c"],
957-
[python_leafspec, python_leafspec, python_leafspec],
962+
[
963+
python_pytree.treespec_leaf(),
964+
python_pytree.treespec_leaf(),
965+
python_pytree.treespec_leaf(),
966+
],
958967
),
959968
],
960969
),
@@ -967,12 +976,15 @@ def test_treespec_repr(self):
967976
tuple,
968977
None,
969978
[
970-
python_leafspec,
971-
python_leafspec,
979+
python_pytree.treespec_leaf(),
980+
python_pytree.treespec_leaf(),
972981
python_pytree.TreeSpec(
973982
list,
974983
None,
975-
[python_leafspec, python_leafspec],
984+
[
985+
python_pytree.treespec_leaf(),
986+
python_pytree.treespec_leaf(),
987+
],
976988
),
977989
],
978990
),
@@ -986,12 +998,12 @@ def test_treespec_repr(self):
986998
python_pytree.TreeSpec(
987999
list,
9881000
None,
989-
[python_leafspec, python_leafspec],
1001+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
9901002
),
9911003
python_pytree.TreeSpec(
9921004
list,
9931005
None,
994-
[python_leafspec, python_leafspec],
1006+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
9951007
),
9961008
python_pytree.TreeSpec(dict, [], []),
9971009
],
@@ -1000,7 +1012,7 @@ def test_treespec_repr(self):
10001012
python_pytree.TreeSpec(
10011013
python_pytree.structseq,
10021014
torch.return_types.sort,
1003-
[python_leafspec, python_leafspec],
1015+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
10041016
),
10051017
],
10061018
)
@@ -1026,7 +1038,7 @@ def test_pytree_serialize_defaultdict_enum(self):
10261038
list,
10271039
None,
10281040
[
1029-
python_leafspec,
1041+
python_pytree.treespec_leaf(),
10301042
],
10311043
),
10321044
],
@@ -1035,7 +1047,7 @@ def test_pytree_serialize_defaultdict_enum(self):
10351047
self.assertIsInstance(serialized_spec, str)
10361048

10371049
def test_pytree_serialize_enum(self):
1038-
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
1050+
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
10391051

10401052
serialized_spec = python_pytree.treespec_dumps(spec)
10411053
self.assertIsInstance(serialized_spec, str)
@@ -1198,12 +1210,20 @@ def test_saved_serialized(self):
11981210
OrderedDict,
11991211
[1, 2, 3],
12001212
[
1201-
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
1202-
python_leafspec,
1213+
python_pytree.TreeSpec(
1214+
tuple,
1215+
None,
1216+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
1217+
),
1218+
python_pytree.treespec_leaf(),
12031219
python_pytree.TreeSpec(
12041220
dict,
12051221
[4, 5, 6],
1206-
[python_leafspec, python_leafspec, python_leafspec],
1222+
[
1223+
python_pytree.treespec_leaf(),
1224+
python_pytree.treespec_leaf(),
1225+
python_pytree.treespec_leaf(),
1226+
],
12071227
),
12081228
],
12091229
)
@@ -1488,7 +1508,7 @@ def setUp(self):
14881508
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
14891509

14901510
def test_treespec_equality(self):
1491-
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
1511+
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
14921512

14931513
def test_treespec_repr(self):
14941514
# Check that it looks sane

torch/_dynamo/polyfills/pytree.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
import builtins
20-
from collections.abc import Iterable
20+
from collections.abc import Iterable, Mapping
2121
from typing_extensions import Self
2222

2323
from torch.utils._cxx_pytree import PyTree
@@ -323,6 +323,61 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
323323
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
324324
return isinstance(obj, PyTreeSpec)
325325

326+
@substitute_in_graph( # type: ignore[arg-type]
327+
cxx_pytree.treespec_leaf,
328+
# We need to disable constant folding here because we want the function to reference the
329+
# PyTreeSpec class defined above, not the one in the C++ module.
330+
can_constant_fold_through=False,
331+
)
332+
def treespec_leaf() -> PyTreeSpec:
333+
return _LEAF_SPEC
334+
335+
@substitute_in_graph( # type: ignore[arg-type]
336+
cxx_pytree.treespec_tuple,
337+
# We need to disable constant folding here because we want the function to reference the
338+
# PyTreeSpec class defined above, not the one in the C++ module.
339+
can_constant_fold_through=False,
340+
)
341+
def treespec_tuple(iterable: Iterable[PyTreeSpec] = (), /) -> PyTreeSpec:
342+
children = tuple(iterable)
343+
if any(not _is_pytreespec_instance(child) for child in children):
344+
raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
345+
handler = optree.register_pytree_node.get(tuple, namespace="torch") # type: ignore[attr-defined]
346+
return PyTreeSpec(
347+
tuple(children),
348+
tuple,
349+
None,
350+
tuple(range(len(children))),
351+
handler.unflatten_func,
352+
)
353+
354+
@substitute_in_graph( # type: ignore[arg-type]
355+
cxx_pytree.treespec_dict,
356+
# We need to disable constant folding here because we want the function to reference the
357+
# PyTreeSpec class defined above, not the one in the C++ module.
358+
can_constant_fold_through=False,
359+
)
360+
def treespec_dict(
361+
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
362+
/,
363+
**kwargs: PyTreeSpec,
364+
) -> PyTreeSpec:
365+
dct = dict(mapping, **kwargs)
366+
if any(not _is_pytreespec_instance(child) for child in dct.values()):
367+
raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
368+
369+
(
370+
children,
371+
metadata,
372+
entries,
373+
unflatten_func,
374+
) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
375+
dct, # type: ignore[arg-type]
376+
none_is_leaf=True,
377+
namespace="torch",
378+
)
379+
return PyTreeSpec(tuple(children), dict, metadata, entries, unflatten_func) # type: ignore[arg-type]
380+
326381
@substitute_in_graph( # type: ignore[arg-type]
327382
cxx_pytree.tree_flatten,
328383
# We need to disable constant folding here because we want the function to reference the

torch/_dynamo/variables/builder.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -3177,9 +3177,7 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker:
31773177
pass # failthrough to unimplemented branch
31783178
elif isinstance(value, torch.fx.graph_module.GraphModule):
31793179
return SourcelessGraphModuleVariable(value)
3180-
elif isinstance(
3181-
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
3182-
):
3180+
elif isinstance(value, torch.utils._pytree.TreeSpec):
31833181
return UserDefinedObjectVariable(value)
31843182
elif PlacementVariable.is_placement(value):
31853183
return PlacementVariable(value)

torch/_functorch/_aot_autograd/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def set(self, spec: pytree.TreeSpec) -> None:
150150
assert spec is not None
151151
self.spec: pytree.TreeSpec = spec
152152
if self.spec.type in {tuple, list} and all(
153-
child.is_leaf() for child in spec.children_specs
153+
child.is_leaf() for child in spec.children()
154154
):
155155
self.is_simple = True
156156
if self.spec.is_leaf():

0 commit comments

Comments
 (0)
0