8000 [pytree] add `treespec_{leaf,tuple,dict}` functions for args_spec mod… · pytorch/pytorch@9970bed · GitHub
[go: up one dir, main page]

Skip to content

Commit 9970bed

Browse files
committed
[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification
ghstack-source-id: 5d64ac5 Pull Request resolved: #138214
1 parent d5eea33 commit 9970bed

File tree

21 files changed

+314
-145
lines changed

21 files changed

+314
-145
lines changed

test/export/test_export.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,13 @@
8484
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
8585
from torch.testing._internal.two_tensor import TwoTensor
8686
from torch.utils._pytree import (
87-
LeafSpec,
8887
register_constant,
8988
tree_flatten,
9089
tree_map,
9190
tree_unflatten,
9291
TreeSpec,
9392
treespec_dumps,
93+
treespec_leaf,
9494
treespec_loads,
9595
)
9696

@@ -5895,7 +5895,7 @@ class MyDataClass:
58955895

58965896
dt = MyDataClass(x=3, y=4)
58975897
flat, spec = tree_flatten(dt)
5898-
self.assertTrue(spec, LeafSpec())
5898+
self.assertTrue(spec, treespec_leaf())
58995899
self.assertTrue(len(flat) == 1)
59005900

59015901
torch.export.register_dataclass(
@@ -5906,7 +5906,9 @@ class MyDataClass:
59065906
flat, spec = tree_flatten(dt)
59075907
self.assertEqual(
59085908
spec,
5909-
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
5909+
TreeSpec(
5910+
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
5911+
),
59105912
)
59115913
self.assertEqual(flat, [3, 4])
59125914

@@ -5939,7 +5941,7 @@ class MyOtherDataClass: # the pytree registration don't allow registering the s
59395941
TreeSpec(
59405942
MyOtherDataClass,
59415943
[["x", "y", "z"], []],
5942-
[LeafSpec(), LeafSpec(), LeafSpec()],
5944+
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
59435945
),
59445946
)
59455947
self.assertEqual(flat, [3, 4, None])

test/test_pytree.py

+54-34
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
6565
A = auto()
6666

6767

68-
python_leafspec = python_pytree.LeafSpec()
69-
70-
7168
class TestGenericPytree(TestCase):
7269
def test_aligned_public_apis(self):
7370
public_apis = python_pytree.__all__
@@ -197,7 +194,7 @@ def test_flatten_unflatten_leaf(self, pytree):
197194
def run_test_with_leaf(leaf):
198195
values, treespec = pytree.tree_flatten(leaf)
199196
self.assertEqual(values, [leaf])
200-
self.assertEqual(treespec, pytree.LeafSpec())
197+
self.assertEqual(treespec, pytree.treespec_leaf())
201198

202199
unflattened = pytree.tree_unflatten(values, treespec)
203200
self.assertEqual(unflattened, leaf)
@@ -215,7 +212,7 @@ def run_test_with_leaf(leaf):
215212
(
216213
python_pytree,
217214
lambda tup: python_pytree.TreeSpec(
218-
tuple, None, [python_leafspec for _ in tup]
215+
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
219216
),
220217
),
221218
name="python",
@@ -250,7 +247,7 @@ def run_test(tup):
250247
(
251248
python_pytree,
252249
lambda lst: python_pytree.TreeSpec(
253-
list, None, [python_leafspec for _ in lst]
250+
list, None, [python_pytree.treespec_leaf() for _ in lst]
254251
),
255252
),
256253
name="python",
@@ -286,7 +283,7 @@ def run_test(lst):
286283
lambda dct: python_pytree.TreeSpec(
287284
dict,
288285
list(dct.keys()),
289-
[python_leafspec for _ in dct.values()],
286+
[python_pytree.treespec_leaf() for _ in dct.values()],
290287
),
291288
),
292289
name="python",
@@ -327,7 +324,7 @@ def run_test(dct):
327324
lambda odict: python_pytree.TreeSpec(
328325
OrderedDict,
329326
list(odict.keys()),
330-
[python_leafspec for _ in odict.values()],
327+
[python_pytree.treespec_leaf() for _ in odict.values()],
331328
),
332329
),
333330
name="python",
@@ -371,7 +368,7 @@ def run_test(odict):
371368
lambda ddct: python_pytree.TreeSpec(
372369
defaultdict,
373370
[ddct.default_factory, list(ddct.keys())],
374-
[python_leafspec for _ in ddct.values()],
371+
[python_pytree.treespec_leaf() for _ in ddct.values()],
375372
),
376373
),
377374
name="python",
@@ -413,7 +410,7 @@ def run_test(ddct):
413410
(
414411
python_pytree,
415412
lambda deq: python_pytree.TreeSpec(
416-
deque, deq.maxlen, [python_leafspec for _ in deq]
413+
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
417414
),
418415
),
419416
name="python",
@@ -453,7 +450,7 @@ def test_flatten_unflatten_namedtuple(self, pytree):
453450
def run_test(tup):
454451
if pytree is python_pytree:
455452
expected_spec = python_pytree.TreeSpec(
456-
namedtuple, Point, [python_leafspec for _ in tup]
453+
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
457454
)
458455
else:
459456
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
@@ -839,16 +836,16 @@ def test_import_pytree_doesnt_import_optree(self):
839836

840837
def test_treespec_equality(self):
841838
self.assertEqual(
842-
python_pytree.LeafSpec(),
843-
python_pytree.LeafSpec(),
839+
python_pytree.treespec_leaf(),
840+
python_pytree.treespec_leaf(),
844841
)
845842
self.assertEqual(
846843
python_pytree.TreeSpec(list, None, []),
847844
python_pytree.TreeSpec(list, None, []),
848845
)
849846
self.assertEqual(
850-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
851-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
847+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
848+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
852849
)
853850
self.assertFalse(
854851
python_pytree.TreeSpec(tuple, None, [])
@@ -883,24 +880,32 @@ def test_treespec_repr(self):
883880
# python_pytree.tree_structure({})
884881
python_pytree.TreeSpec(dict, [], []),
885882
# python_pytree.tree_structure([0])
886-
python_pytree.TreeSpec(list, None, [python_leafspec]),
883+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
887884
# python_pytree.tree_structure([0, 1])
888885
python_pytree.TreeSpec(
889886
list,
890887
None,
891-
[python_leafspec, python_leafspec],
888+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
892889
),
893890
# python_pytree.tree_structure((0, 1, 2))
894891
python_pytree.TreeSpec(
895892
tuple,
896893
None,
897-
[python_leafspec, python_leafspec, python_leafspec],
894+
[
895+
python_pytree.treespec_leaf(),
896+
python_pytree.treespec_leaf(),
897+
python_pytree.treespec_leaf(),
898+
],
898899
),
899900
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
900901
python_pytree.TreeSpec(
901902
dict,
902903
["a", "b", "c"],
903-
[python_leafspec, python_leafspec, python_leafspec],
904+
[
905+
python_pytree.treespec_leaf(),
906+
python_pytree.treespec_leaf(),
907+
python_pytree.treespec_leaf(),
908+
],
904909
),
905910
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
906911
python_pytree.TreeSpec(
@@ -910,13 +915,17 @@ def test_treespec_repr(self):
910915
python_pytree.TreeSpec(
911916
tuple,
912917
None,
913-
[python_leafspec, python_leafspec],
918+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
914919
),
915-
python_leafspec,
920+
python_pytree.treespec_leaf(),
916921
python_pytree.TreeSpec(
917922
dict,
918923
["a", "b", "c"],
919-
[python_leafspec, python_leafspec, python_leafspec],
924+
[
925+
python_pytree.treespec_leaf(),
926+
python_pytree.treespec_leaf(),
927+
python_pytree.treespec_leaf(),
928+
],
920929
),
921930
],
922931
),
@@ -929,12 +938,15 @@ def test_treespec_repr(self):
929938
tuple,
930939
None,
931940
[
932-
python_leafspec,
933-
python_leafspec,
941+
python_pytree.treespec_leaf(),
942+
python_pytree.treespec_leaf(),
934943
python_pytree.TreeSpec(
935944
list,
936945
None,
937-
[python_leafspec, python_leafspec],
946+
[
947+
python_pytree.treespec_leaf(),
948+
python_pytree.treespec_leaf(),
949+
],
938950
),
939951
],
940952
),
@@ -948,12 +960,12 @@ def test_treespec_repr(self):
948960
python_pytree.TreeSpec(
949961
list,
950962
None,
951-
[python_leafspec, python_leafspec],
963+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
952964
),
953965
python_pytree.TreeSpec(
954966
list,
955967
None,
956-
[python_leafspec, python_leafspec],
968+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
957969
),
958970
python_pytree.TreeSpec(dict, [], []),
959971
],
@@ -962,7 +974,7 @@ def test_treespec_repr(self):
962974
python_pytree.TreeSpec(
963975
python_pytree.structseq,
964976
torch.return_types.sort,
965-
[python_leafspec, python_leafspec],
977+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
966978
),
967979
],
968980
)
@@ -988,7 +1000,7 @@ def test_pytree_serialize_defaultdict_enum(self):
9881000
list,
9891001
None,
9901002
[
991-
python_leafspec,
1003+
python_pytree.treespec_leaf(),
9921004
],
9931005
),
9941006
],
@@ -997,7 +1009,7 @@ def test_pytree_serialize_defaultdict_enum(self):
9971009
self.assertIsInstance(serialized_spec, str)
9981010

9991011
def test_pytree_serialize_enum(self):
1000-
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
1012+
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
10011013

10021014
serialized_spec = python_pytree.treespec_dumps(spec)
10031015
self.assertIsInstance(serialized_spec, str)
@@ -1160,12 +1172,20 @@ def test_saved_serialized(self):
11601172
OrderedDict,
11611173
[1, 2, 3],
11621174
[
1163-
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
1164-
python_leafspec,
1175+
python_pytree.TreeSpec(
1176+
tuple,
1177+
None,
1178+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
1179+
),
1180+
python_pytree.treespec_leaf(),
11651181
python_pytree.TreeSpec(
11661182
dict,
11671183
[4, 5, 6],
1168-
[python_leafspec, python_leafspec, python_leafspec],
1184+
[
1185+
python_pytree.treespec_leaf(),
1186+
python_pytree.treespec_leaf(),
1187+
python_pytree.treespec_leaf(),
1188+
],
11691189
),
11701190
],
11711191
)
@@ -1450,7 +1470,7 @@ def setUp(self):
14501470
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
14511471

14521472
def test_treespec_equality(self):
1453-
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
1473+
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
14541474

14551475
def test_treespec_repr(self):
14561476
# Check that it looks sane

torch/_dynamo/polyfills/pytree.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
import builtins
20-
from collections.abc import Iterable
20+
from collections.abc import Iterable, Mapping
2121
from typing_extensions import Self
2222

2323

@@ -324,6 +324,61 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
324324
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
325325
return isinstance(obj, PyTreeSpec)
326326

327+
@substitute_in_graph( # type: ignore[arg-type]
328+
cxx_pytree.treespec_leaf,
329+
# We need to disable constant folding here because we want the function to reference the
330+
# PyTreeSpec class defined above, not the one in the C++ module.
331+
can_constant_fold_through=False,
332+
)
333+
def treespec_leaf() -> PyTreeSpec:
334+
return _LEAF_SPEC
335+
336+
@substitute_in_graph( # type: ignore[arg-type]
337+
cxx_pytree.treespec_tuple,
338+
# We need to disable constant folding here because we want the function to reference the
339+
# PyTreeSpec class defined above, not the one in the C++ module.
340+
can_constant_fold_through=False,
341+
)
342+
def treespec_tuple(iterable: Iterable[PyTreeSpec] = (), /) -> PyTreeSpec:
343+
children = tuple(iterable)
344+
if any(not _is_pytreespec_instance(child) for child in children):
345+
raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
346+
handler = optree.register_pytree_node.get(tuple, namespace="torch") # type: ignore[attr-defined]
347+
return PyTreeSpec(
348+
tuple(children),
349+
tuple,
350+
None,
351+
tuple(range(len(children))),
352+
handler.unflatten_func,
353+
)
354+
355+
@substitute_in_graph( # type: ignore[arg-type]
356+
cxx_pytree.treespec_dict,
357+
# We need to disable constant folding here because we want the function to reference the
358+
# PyTreeSpec class defined above, not the one in the C++ module.
359+
can_constant_fold_through=False,
360+
)
361+
def treespec_dict(
362+
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
363+
/,
364+
**kwargs: PyTreeSpec,
365+
) -> PyTreeSpec:
366+
dct = dict(mapping, **kwargs)
367+
if any(not _is_pytreespec_instance(child) for child in dct.values()):
368+
raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
369+
370+
(
371+
children,
372+
metadata,
373+
entries,
374+
unflatten_func,
375+
) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
376+
dct, # type: ignore[arg-type]
377+
none_is_leaf=True,
378+
namespace="torch",
379+
)
380+
return PyTreeSpec(tuple(children), dict, metadata, entries, unflatten_func) # type: ignore[arg-type]
381+
327382
@substitute_in_graph( # type: ignore[arg-type]
328383
cxx_pytree.tree_flatten,
329384
# We need to disable constant folding here because we want the function to reference the

torch/_dynamo/variables/builder.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -3344,9 +3344,7 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker:
33443344
pass # failthrough to unimplemented branch
33453345
elif isinstance(value, torch.fx.graph_module.GraphModule):
33463346
return SourcelessGraphModuleVariable(value)
3347-
elif isinstance(
3348-
value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
3349-
):
3347+
elif isinstance(value, torch.utils._pytree.TreeSpec):
33503348
return UserDefinedObjectVariable(value)
33513349
elif PlacementVariable.is_placement(value):
33523350
return PlacementVariable(value)

torch/_functorch/_aot_autograd/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def set(self, spec: pytree.TreeSpec) -> None:
150150
assert spec is not None
151151
self.spec: pytree.TreeSpec = spec
152152
if self.spec.type in {tuple, list} and all(
153-
child.is_leaf() for child in spec.children_specs
153+
child.is_leaf() for child in spec.children()
154154
):
155155
self.is_simple = True
156156
if self.spec.is_leaf():

0 commit comments

Comments
 (0)
0