diff --git a/.lintrunner.toml b/.lintrunner.toml index dbc9fb7b3f70d..e40ecac536318 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -177,6 +177,7 @@ include_patterns = [ 'benchmarks/instruction_counts/**/*.py', 'tools/**/*.py', 'torchgen/**/*.py', + 'torch/utils/pytree/__init__.py', 'torch/utils/_pytree.py', 'torch/utils/_cxx_pytree.py', 'torch/utils/benchmark/utils/common.py', diff --git a/docs/source/pytorch-api.md b/docs/source/pytorch-api.md index 1083354f3b3ca..62cd5b2d955b5 100644 --- a/docs/source/pytorch-api.md +++ b/docs/source/pytorch-api.md @@ -67,6 +67,7 @@ sparse storage torch.testing torch.utils +torch.utils.pytree torch.utils.benchmark torch.utils.bottleneck torch.utils.checkpoint diff --git a/docs/source/torch.utils.pytree.rst b/docs/source/torch.utils.pytree.rst new file mode 100644 index 0000000000000..b84faa4f9ef84 --- /dev/null +++ b/docs/source/torch.utils.pytree.rst @@ -0,0 +1,7 @@ +torch.utils.pytree +================== + +.. currentmodule:: torch.utils.pytree + +.. 
automodule:: torch.utils.pytree + :members: diff --git a/mypy-strict.ini b/mypy-strict.ini index 2feea92cb8c05..2e92fad971fb9 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -29,6 +29,7 @@ files = benchmarks/instruction_counts, tools, torch/profiler/_memory_profiler.py, + torch/utils/pytree/__init__.py, torch/utils/_pytree.py, torch/utils/_cxx_pytree.py, torch/utils/benchmark/utils/common.py, diff --git a/test/dynamo/test_flat_apply.py b/test/dynamo/test_flat_apply.py index 8e5d945299186..e26bfee0bc47e 100644 --- a/test/dynamo/test_flat_apply.py +++ b/test/dynamo/test_flat_apply.py @@ -147,8 +147,8 @@ def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"): t: "f32[10]" = l_x_ + l_y_ - trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec - trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec + trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec + trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None return (res,) """, # NOQA: B950 diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 1bc1904034300..3da961fd6a4fc 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -37,6 +37,7 @@ import torch.onnx.operators import torch.utils._pytree as python_pytree import torch.utils.cpp_extension +import torch.utils.pytree as generic_pytree from torch import Tensor from torch._C import FileCheck from torch._dynamo import allow_in_graph @@ -83,12 +84,15 @@ ) from torch.testing._internal.common_utils import ( freeze_rng_state, + instantiate_parametrized_tests, IS_FBCODE, + parametrize, scoped_load_inline, set_default_dtype, skipIfHpu, skipIfNNModuleInlined, skipIfWindows, + subtest, 
TEST_HPU, wrapDeterministicFlagAPITest, ) @@ -96,11 +100,22 @@ from torch.testing._internal.logging_utils import logs_to_string +pytree_modules = { + "generic": generic_pytree, + "python": python_pytree, +} if python_pytree._cxx_pytree_dynamo_traceable: import torch.utils._cxx_pytree as cxx_pytree + + pytree_modules["cxx"] = cxx_pytree else: cxx_pytree = None +parametrize_pytree_module = parametrize( + "pytree", + [subtest(module, name=name) for name, module in pytree_modules.items()], +) + MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"]) T = typing.TypeVar("T") @@ -8698,71 +8713,6 @@ def fn(): opt = torch.compile(fn, backend="eager") opt() - def test_tracing_py_tree(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - - counter = CompileCounter() - torch.compile(fn, backend=counter, fullgraph=True)(xs) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 3) - - def test_tracing_nested_py_tree(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsl = [xs, xs, xs, xs] - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 12) - - def test_tracing_nested_py_tree_tuples(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsl = (xs, xs, xs, xs) - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - 
self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 12) - - def test_tracing_nested_py_tree_dicts(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsl = { - "a": xs, - "b": xs, - "c": xs, - } - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 9) - def test_dynamic_one_hot(self): def fn(x): x = x + 1 @@ -8779,28 +8729,6 @@ def fn(x): self.assertEqual(counter.frame_count, 2) self.assertEqual(counter.op_count, 2) - def test_tracing_nested_py_tree_mixed_all(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsa = (xs, xs) - xsb = {"aa": xsa, "ab": xs} - xsl = { - "a": xs, - "b": xsa, - "c": xsb, - } - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 18) - def test_any_all_symnode(self): cnt = CompileCounter() @@ -8827,46 +8755,6 @@ def fn(x): self.assertEqual(fn(y3), y3 - 3) self.assertEqual(cnt.frame_count, 2) - def test_tracing_py_tree_tensor_subclass(self): - from torch.testing._internal.two_tensor import TwoTensor - from torch.utils.checkpoint import checkpoint - - def fn(xs): - nested_xs = [[xs]] - flat_xs, spec = python_pytree.tree_flatten(xs) - return flat_xs[0].clone() - - # use checkpoint to trigger a "sourceless" tensor subclass - def checkpoint_fn(xs): - return checkpoint(fn, xs, use_reentrant=True) - - xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2)) - - 
counter = CompileCounter() - torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 2) - - def test_tracing_tree_map_only(self): - def fn(xs): - def mapper(x): - return x.clone() - - y = python_pytree.tree_map_only(torch.Tensor, mapper, xs) - return y - - xs = [torch.tensor(i) for i in range(3)] + ["hi"] - xsa = (xs, xs) - xsb = {"aa": xsa, "ab": xs} - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb) - real_out = fn(xsb) - - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 9) - @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) @@ -10242,139 +10130,6 @@ def fn(x, y): self.assertEqual(actual, expected) - def test_pytree_tree_leaves(self): - implemtations = [("python", python_pytree)] - if cxx_pytree is not None: - implemtations.append(("cxx", cxx_pytree)) - - for name, module in implemtations: - with self.subTest(f"pytree implement: {name}"): - - def fn(x): - tree = { - "a": [x, x - 1], - "b": x + 2, - "c": ( - x, - 3.0, - collections.deque([0.0, -x, 1, 2], maxlen=3), - ), - "d": collections.OrderedDict( - { - "e": torch.return_types.qr((2 * x, None)), - "f": MyTuple(x, x + 1, torch.zeros(4, 3)), - }, - ), - } - leaves = module.tree_leaves(tree) - return leaves - - x = torch.randn(3, 2) - expected = fn(x) - fn_opt = torch.compile(fullgraph=True)(fn) - actual = fn_opt(x) - - self.assertEqual(actual, expected) - - def test_pytree_tree_flatten_unflatten(self): - implemtations = [("python", python_pytree)] - if cxx_pytree is not None: - implemtations.append(("cxx", cxx_pytree)) - - for name, module in implemtations: - with self.subTest(f"pytree implement: {name}"): - - def fn(x, y): - tree = { - "a": [x, x - 1], - "b": x + 2, - "c": ( - x, - 3.0, - collections.deque([0.0, -x, 1, 2], maxlen=3), - ), - "d": 
collections.OrderedDict( - { - "e": torch.return_types.qr((2 * x, None)), - "f": MyTuple(x, x + 1, torch.zeros(4, 3)), - }, - ), - } - leaves, treespec = module.tree_flatten(tree) - new_leaves = [ - x - 1, - y, - x * y, - 3.0, - y - 2, - 1, - torch.zeros(2, 2), - 2 * y, - -y, - x + y, - x - y, - torch.ones(3, 2), - 1, - ] - new_tree = module.tree_unflatten(new_leaves, treespec) - return leaves, new_tree - - x = torch.randn(3, 2) - y = torch.randn(3, 2) - expected = fn(x, y) - fn_opt = torch.compile(fullgraph=True)(fn) - actual = fn_opt(x, y) - - self.assertEqual(actual, expected) - - def test_pytree_tree_map(self): - implemtations = [("python", python_pytree)] - if cxx_pytree is not None: - implemtations.append(("cxx", cxx_pytree)) - - for name, module in implemtations: - with self.subTest(f"pytree implement: {name}"): - - def fn(x, y): - tree1 = { - "a": [x, x - 1], - "b": x + 2, - "c": ( - x, - 3.0, - collections.deque([0.0, -x, 1, 2], maxlen=3), - ), - "d": collections.OrderedDict( - { - "e": torch.return_types.qr((2 * x, None)), - "f": MyTuple(x, x + 1, torch.zeros(4, 3)), - }, - ), - } - tree2 = collections.OrderedDict( - [ - ("c", (y, 3.0, collections.deque([1, -y, 10.0]))), - ("a", [y, y + 1]), - ("b", y + 2), - ( - "d", - { - "f": MyTuple(torch.ones(4, 3), -y, y + 1), - "e": torch.return_types.qr((2 * y, None)), - }, - ), - ], - ) - return module.tree_map(lambda u, v: (u, v), tree1, tree2) - - x = torch.randn(3, 2) - y = torch.randn(3, 2) - expected = fn(x, y) - fn_opt = torch.compile(fullgraph=True)(fn) - actual = fn_opt(x, y) - - self.assertEqual(actual, expected) - def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) @@ -12200,6 +11955,257 @@ def fn(x, y): self.assertEqual(fn(*inputs), inputs[0]) +class MiscTestsPyTree(torch._inductor.test_case.TestCase): + @parametrize_pytree_module + def test_tracing_pytree(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return 
pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + + counter = CompileCounter() + torch.compile(fn, backend=counter, fullgraph=True)(xs) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 3) + + @parametrize_pytree_module + def test_tracing_nested_pytree(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = [xs, xs, xs, xs] + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 12) + + @parametrize_pytree_module + def test_tracing_nested_tuples(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = (xs, xs, xs, xs) + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 12) + + @parametrize_pytree_module + def test_tracing_nested_dicts(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = { + "a": xs, + "b": xs, + "c": xs, + } + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 9) + + @parametrize_pytree_module + def test_tracing_nested_mixed_all(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = 
[x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsa = (xs, xs) + xsb = {"aa": xsa, "ab": xs} + xsl = { + "a": xs, + "b": xsa, + "c": xsb, + } + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 18) + + @parametrize_pytree_module + def test_tracing_nested_tensor_subclass(self, pytree): + from torch.testing._internal.two_tensor import TwoTensor + from torch.utils.checkpoint import checkpoint + + def fn(xs): + nested_xs = [[xs]] + flat_xs, spec = pytree.tree_flatten(xs) + return flat_xs[0].clone() + + # use checkpoint to trigger a "sourceless" tensor subclass + def checkpoint_fn(xs): + return checkpoint(fn, xs, use_reentrant=True) + + xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2)) + + counter = CompileCounter() + torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 2) + + @parametrize_pytree_module + def test_pytree_tree_leaves(self, pytree): + def fn(x): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves = pytree.tree_leaves(tree) + return leaves + + x = torch.randn(3, 2) + expected = fn(x) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x) + + self.assertEqual(actual, expected) + + @parametrize_pytree_module + def test_pytree_tree_flatten_unflatten(self, pytree): + def fn(x, y): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, 
x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves, treespec = pytree.tree_flatten(tree) + new_leaves = [ + x - 1, + y, + x * y, + 3.0, + y - 2, + 1, + torch.zeros(2, 2), + 2 * y, + -y, + x + y, + x - y, + torch.ones(3, 2), + 1, + ] + new_tree = pytree.tree_unflatten(new_leaves, treespec) + return leaves, new_tree + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + + @parametrize_pytree_module + def test_pytree_tree_map(self, pytree): + def fn(x, y): + tree1 = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + tree2 = collections.OrderedDict( + [ + ("c", (y, 3.0, collections.deque([1, -y, 10.0]))), + ("a", [y, y + 1]), + ("b", y + 2), + ( + "d", + { + "f": MyTuple(torch.ones(4, 3), -y, y + 1), + "e": torch.return_types.qr((2 * y, None)), + }, + ), + ], + ) + return pytree.tree_map(lambda u, v: (u, v), tree1, tree2) + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + + @parametrize_pytree_module + def test_pytree_tree_map_only(self, pytree): + def fn(xs): + def mapper(x): + return x.clone() + + y = pytree.tree_map_only(torch.Tensor, mapper, xs) + return y + + xs = [torch.tensor(i) for i in range(3)] + ["hi"] + xsa = (xs, xs) + xsb = {"aa": xsa, "ab": xs} + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb) + real_out = fn(xsb) + + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 9) + + class TestTracer(JitTestCase): def test_jit_save(self): def fn(): @@ -12517,8 +12523,12 @@ def f(): 
self.assertEqual(ref, res) +instantiate_parametrized_tests(MiscTestsPyTree) + devices = ("cuda", "hpu") instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/export/test_swap.py b/test/export/test_swap.py index 8833c3c94ae7b..10e003057f0a1 100644 --- a/test/export/test_swap.py +++ b/test/export/test_swap.py @@ -246,11 +246,11 @@ def forward(self, x, y): _spec_0 = self._spec_0 _spec_1 = self._spec_1 _spec_4 = self._spec_4 - tree_flatten = torch.utils._pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None + tree_flatten = torch.utils.pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None getitem = tree_flatten[0]; tree_flatten = None x = getitem[0] y = getitem[1]; getitem = None - tree_unflatten_1 = torch.utils._pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None + tree_unflatten_1 = torch.utils.pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None getitem_2 = getitem_1[0] getitem_3 = getitem_1[1]; getitem_1 = None @@ -258,7 +258,7 @@ def forward(self, x, y): bar = self.bar(foo); foo = None tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None - tree_unflatten = torch.utils._pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None + tree_unflatten = torch.utils.pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None return tree_unflatten""", ) diff --git a/test/test_pytree.py b/test/test_pytree.py index 82665854c2b13..7de4ce5c77c79 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -889,12 +889,13 @@ def test_import_pytree_doesnt_import_optree(self): script = """ import sys import torch -import torch.utils._pytree +assert "torch.utils.pytree" in sys.modules assert "torch.utils._pytree" in sys.modules -if "torch.utils._cxx_pytree" in 
sys.modules: - raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree") -if "optree" in sys.modules: - raise RuntimeError("importing torch.utils._pytree should not import optree") +if not torch.utils.pytree.PYTORCH_USE_CXX_PYTREE: + if "torch.utils._cxx_pytree" in sys.modules: + raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree") + if "optree" in sys.modules: + raise RuntimeError("importing torch.utils._pytree should not import optree") """ try: subprocess.check_output( diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index e7f149f807301..9702fa733e692 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -1,3 +1,5 @@ +# Owner(s): ["module: pytree"] + """ Python polyfills for torch.utils.pytree """ @@ -7,7 +9,6 @@ from collections import deque from dataclasses import dataclass, field from typing import Any, Callable, Literal, TYPE_CHECKING -from typing_extensions import TypeIs import torch.utils._pytree as python_pytree from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES @@ -18,7 +19,7 @@ if TYPE_CHECKING: import builtins from collections.abc import Iterable - from typing_extensions import Self + from typing_extensions import Self, TypeIs __all__: list[str] = [] diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index f04e360894828..0cb37835b9f49 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3323,6 +3323,7 @@ def _module_dir(m: types.ModuleType): "torch.utils._python_dispatch", "torch.utils._pytree", "torch.utils.hooks", + "torch.utils.pytree", ] assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST MOD_INLINELIST = set(MOD_INLINELIST) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index a5a304b138bd9..a5829dc889e20 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -37,13 +37,13 @@ _register_pytree_node, Context, 
FlattenFunc, - FromDumpableContextFn, + FromDumpableContextFunc, GetAttrKey, KeyPath, keystr, MappingKey, SequenceKey, - ToDumpableContextFn, + ToDumpableContextFunc, tree_flatten_with_path, UnflattenFunc, ) @@ -438,8 +438,8 @@ def register_dataclass_as_pytree_node( unflatten_fn: Optional[UnflattenFunc] = None, *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + to_dumpable_context: Optional[ToDumpableContextFunc] = None, + from_dumpable_context: Optional[FromDumpableContextFunc] = None, return_none_fields: bool = False, ) -> None: assert dataclasses.is_dataclass( diff --git a/torch/export/__init__.py b/torch/export/__init__.py index d2b208ca19b66..e5e330fb3be72 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -19,8 +19,8 @@ from torch.types import FileLike from torch.utils._pytree import ( FlattenFunc, - FromDumpableContextFn, - ToDumpableContextFn, + FromDumpableContextFunc, + ToDumpableContextFunc, UnflattenFunc, ) diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index 23188bba9b800..16e260224053d 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -11,6 +11,7 @@ data as data, deterministic as deterministic, hooks as hooks, + pytree as pytree, ) from torch.utils.backend_registration import ( generate_methods_for_privateuse1_backend, diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 40c366b3099eb..54f8b359e70dc 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -1,3 +1,5 @@ +# Owner(s): ["module: pytree"] + """ Contains utility functions for working with nested python data structures. 
@@ -17,18 +19,26 @@ import types from collections.abc import Iterable from typing import Any, Callable, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, TypeIs +from typing_extensions import deprecated, Self, TypeAlias, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion from torch.utils._pytree import ( - is_namedtuple as is_namedtuple, - is_namedtuple_class as is_namedtuple_class, - is_namedtuple_instance as is_namedtuple_instance, - is_structseq as is_structseq, - is_structseq_class as is_structseq_class, - is_structseq_instance as is_structseq_instance, - KeyEntry as KeyEntry, + Context, + DumpableContext, + FlattenFunc, + FlattenWithKeysFunc, + FromDumpableContextFunc, + is_namedtuple, + is_namedtuple_class, + is_namedtuple_instance, + is_structseq, + is_structseq_class, + is_structseq_instance, + KeyPath, + PyTree, + ToDumpableContextFunc, + UnflattenFunc, ) @@ -43,7 +53,7 @@ import optree -from optree import PyTreeSpec as TreeSpec # direct import for type annotations +from optree import PyTreeSpec # direct import for type annotations __all__ = [ @@ -52,8 +62,9 @@ "FlattenFunc", "UnflattenFunc", "DumpableContext", - "ToDumpableContextFn", - "FromDumpableContextFn", + "ToDumpableContextFunc", + "FromDumpableContextFunc", + "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", @@ -100,17 +111,8 @@ U = TypeVar("U") R = TypeVar("R") - -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]] -UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] -OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree] -DumpableContext = Any # Any json dumpable text -ToDumpableContextFn = Callable[[Context], DumpableContext] -FromDumpableContextFn = Callable[[DumpableContext], Context] -KeyPath = tuple[KeyEntry, ...] 
-FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]] +TreeSpec: TypeAlias = PyTreeSpec +OpTreeUnflattenFunc: TypeAlias = Callable[[Context, Iterable[Any]], PyTree] def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc: @@ -127,8 +129,8 @@ def register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + to_dumpable_context: Optional[ToDumpableContextFunc] = None, + from_dumpable_context: Optional[FromDumpableContextFunc] = None, flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, ) -> None: """Register a container-like type as pytree node. @@ -195,8 +197,8 @@ def _register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + to_dumpable_context: Optional[ToDumpableContextFunc] = None, + from_dumpable_context: Optional[FromDumpableContextFunc] = None, ) -> None: """Register a container-like type as pytree node for the C++ pytree only. @@ -246,8 +248,8 @@ def _private_register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + to_dumpable_context: Optional[ToDumpableContextFunc] = None, + from_dumpable_context: Optional[FromDumpableContextFunc] = None, ) -> None: """This is an internal function that is used to register a pytree node type for the C++ pytree only. 
End-users should use :func:`register_pytree_node` @@ -995,16 +997,22 @@ def treespec_loads(serialized: str) -> TreeSpec: return treespec -class _DummyLeaf: +class _Asterisk(str): + __slots__ = () + + def __new__(cls) -> Self: + return super().__new__(cls, "*") + def __repr__(self) -> str: - return "*" + return "*" # no quotes + + +_asterisk = _Asterisk() +del _Asterisk def treespec_pprint(treespec: TreeSpec) -> str: - dummy_tree = tree_unflatten( - [_DummyLeaf() for _ in range(treespec.num_leaves)], - treespec, - ) + dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec) return repr(dummy_tree) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 169bc4186de99..e81ab5e4e8948 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -1,3 +1,5 @@ +# Owner(s): ["module: pytree"] + """ Contains utility functions for working with nested python data structures. @@ -42,7 +44,7 @@ TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self, TypeIs +from typing_extensions import deprecated, NamedTuple, Self, TypeAlias, TypeIs from torch.torch_version import TorchVersion as _TorchVersion @@ -53,8 +55,9 @@ "FlattenFunc", "UnflattenFunc", "DumpableContext", - "ToDumpableContextFn", - "FromDumpableContextFn", + "ToDumpableContextFunc", + "FromDumpableContextFunc", + "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", @@ -120,17 +123,21 @@ def default(self, obj: object) -> str: return super().default(obj) # type: ignore[no-any-return] -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]] -UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] -DumpableContext = Any # Any json dumpable text -ToDumpableContextFn = Callable[[Context], DumpableContext] -FromDumpableContextFn = Callable[[DumpableContext], Context] -ToStrFunc = Callable[["TreeSpec", list[str]], str] -MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]] -KeyPath = tuple[KeyEntry, ...] 
-FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]] +Context: TypeAlias = Any +PyTree: TypeAlias = Any +FlattenFunc: TypeAlias = Callable[[PyTree], tuple[list[Any], Context]] +UnflattenFunc: TypeAlias = Callable[[Iterable[Any], Context], PyTree] +DumpableContext: TypeAlias = Any # Any json dumpable text +ToDumpableContextFunc: TypeAlias = Callable[[Context], DumpableContext] +FromDumpableContextFunc: TypeAlias = Callable[[DumpableContext], Context] +ToDumpableContextFn: TypeAlias = ToDumpableContextFunc +FromDumpableContextFn: TypeAlias = FromDumpableContextFunc +ToStrFunc: TypeAlias = Callable[["TreeSpec", list[str]], str] +MaybeFromStrFunc: TypeAlias = Callable[[str], Optional[tuple[Any, Context, str]]] +KeyPath: TypeAlias = tuple[KeyEntry, ...] +FlattenWithKeysFunc: TypeAlias = Callable[ + [PyTree], tuple[list[tuple[KeyEntry, Any]], Any] +] # A NodeDef holds two callables: @@ -163,8 +170,8 @@ class NodeDef(NamedTuple): class _SerializeNodeDef(NamedTuple): typ: type[Any] serialized_type_name: str - to_dumpable_context: Optional[ToDumpableContextFn] - from_dumpable_context: Optional[FromDumpableContextFn] + to_dumpable_context: Optional[ToDumpableContextFunc] + from_dumpable_context: Optional[FromDumpableContextFunc] SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {} @@ -201,8 +208,8 @@ def register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + to_dumpable_context: Optional[ToDumpableContextFunc] = None, + from_dumpable_context: Optional[FromDumpableContextFunc] = None, flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, ) -> None: """Register a container-like type as pytree node. 
@@ -524,8 +531,8 @@ def _register_pytree_node( maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + to_dumpable_context: Optional[ToDumpableContextFunc] = None, + from_dumpable_context: Optional[FromDumpableContextFunc] = None, flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, ) -> None: """Register a container-like type as pytree node for the Python pytree only. @@ -591,8 +598,8 @@ def _private_register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + to_dumpable_context: Optional[ToDumpableContextFunc] = None, + from_dumpable_context: Optional[FromDumpableContextFunc] = None, flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, ) -> None: """This is an internal function that is used to register a pytree node type @@ -1057,7 +1064,9 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) - # children_specs: specs for each child of the root Node # num_leaves: the number of leaves @dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) -class TreeSpec: +class PyTreeSpec: + """Representing the structure of the pytree.""" + type: Any context: Context children_specs: list["TreeSpec"] @@ -1104,15 +1113,18 @@ def __eq__(self, other: PyTree) -> bool: return NotImplemented def is_leaf(self) -> bool: + """Test whether the treespec represents a leaf.""" return self.num_nodes == 1 and self.num_leaves == 1 def flatten_up_to(self, tree: PyTree) -> list[PyTree]: - def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: + """Flatten the subtrees in ``tree`` up to the structure of this treespec and return a list of subtrees.""" + + def helper(treespec: 
TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None: if treespec.is_leaf(): - subtrees.append(tree) + subtrees.append(node) return - node_type = _get_node_type(tree) + node_type = _get_node_type(node) if treespec.type not in BUILTIN_TYPES: # Always require custom node types to match exactly if node_type != treespec.type: @@ -1121,7 +1133,7 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: f"expected {treespec.type!r}, but got {node_type!r}.", ) flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - children, context = flatten_fn(tree) + children, context = flatten_fn(node) if len(children) != treespec.num_children: raise ValueError( f"Node arity mismatch; " @@ -1143,10 +1155,10 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: f"Node type mismatch; " f"expected {treespec.type!r}, but got {node_type!r}.", ) - if len(tree) != treespec.num_children: + if len(node) != treespec.num_children: raise ValueError( f"Node arity mismatch; " - f"expected {treespec.num_children}, but got {len(tree)}.", + f"expected {treespec.num_children}, but got {len(node)}.", ) if both_standard_dict: @@ -1158,7 +1170,7 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: else treespec.context[1] ) expected_keys = dict_context - got_key_set = set(tree) + got_key_set = set(node) expected_key_set = set(expected_keys) if got_key_set != expected_key_set: missing_keys = expected_key_set.difference(got_key_set) @@ -1169,11 +1181,11 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: if extra_keys: message += f"; extra key(s): {extra_keys}" raise ValueError(f"Node keys mismatch{message}.") - children = [tree[key] for key in expected_keys] + children = [node[key] for key in expected_keys] else: # node_type is treespec.type flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - children, context = flatten_fn(tree) + children, context = flatten_fn(node) if ( node_type is not deque # 
ignore mismatch of `maxlen` for deque ) and context != treespec.context: @@ -1190,6 +1202,7 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: return subtrees def unflatten(self, leaves: Iterable[Any]) -> PyTree: + """Reconstruct a pytree from the leaves.""" if not isinstance(leaves, (list, tuple)): leaves = list(leaves) if len(leaves) != self.num_leaves: @@ -1235,6 +1248,9 @@ def __hash__(self) -> int: return hash((node_type, hashable_context, tuple(self.children_specs))) +TreeSpec: TypeAlias = PyTreeSpec + + # NOTE: subclassing a dataclass is subtle. In order to enable reasoning about # this class with `dataclasses.fields`, etc., while having a simplified # constructor that takes no argument, we wrap with `dataclass(init=True, ...)` @@ -1935,16 +1951,22 @@ def treespec_loads(serialized: str) -> TreeSpec: ) -class _DummyLeaf: +class _Asterisk(str): + __slots__ = () + + def __new__(cls) -> Self: + return super().__new__(cls, "*") + def __repr__(self) -> str: - return "*" + return "*" # no quotes + + +_asterisk = _Asterisk() +del _Asterisk def treespec_pprint(treespec: TreeSpec) -> str: - dummy_tree = tree_unflatten( - [_DummyLeaf() for _ in range(treespec.num_leaves)], - treespec, - ) + dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec) return repr(dummy_tree) diff --git a/torch/utils/pytree/__init__.py b/torch/utils/pytree/__init__.py new file mode 100644 index 0000000000000..3f8229778ad1b --- /dev/null +++ b/torch/utils/pytree/__init__.py @@ -0,0 +1,216 @@ +# Owner(s): ["module: pytree"] + +""" +Contains utility functions for working with nested python data structures. + +A *pytree* is a nested Python data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. 
For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. +""" + +import os as _os +import sys as _sys +from typing import Any as _Any, Optional as _Optional + +import torch.utils._pytree as python +from torch.utils._exposed_in import exposed_in as _exposed_in +from torch.utils._pytree import ( # these type aliases are identical in both implementations + FlattenFunc, + FlattenWithKeysFunc, + FromDumpableContextFunc, + PyTree, + ToDumpableContextFunc, + UnflattenFunc, +) + + +__all__ = [ + "PyTreeSpec", + "register_pytree_node", + "tree_flatten", + "tree_unflatten", + "tree_iter", + "tree_leaves", + "tree_structure", + "tree_map", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", +] + + +# NB: Once this variable is read from the environment, the underlying pytree +# implementation is frozen. It cannot be swapped to another at runtime. +PYTORCH_USE_CXX_PYTREE: bool = _os.getenv("PYTORCH_USE_CXX_PYTREE", "0") not in { + "0", + "", +} + + +if PYTORCH_USE_CXX_PYTREE: + import torch.utils._cxx_pytree as cxx # noqa: F401 + + if not python._cxx_pytree_dynamo_traceable: + raise ImportError( + "Cannot import package `optree`. " + "Please install `optree` via `python -m pip install --upgrade optree`. " + "Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`." 
+ ) + + +_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment] + + +if not PYTORCH_USE_CXX_PYTREE: + from torch.utils._pytree import ( + is_namedtuple, + is_namedtuple_class, + is_namedtuple_instance, + is_structseq, + is_structseq_class, + is_structseq_instance, + PyTreeSpec, + register_pytree_node as _register_pytree_node, + tree_all, + tree_all_only, + tree_any, + tree_any_only, + tree_flatten, + tree_iter, + tree_leaves, + tree_map, + tree_map_, + tree_map_only, + tree_map_only_, + tree_structure, + tree_unflatten, + treespec_pprint, + ) + + PyTreeSpec = _exposed_in(__name__)(PyTreeSpec) # type: ignore[misc] +else: + from torch.utils._cxx_pytree import ( # type: ignore[assignment,no-redef] + is_namedtuple, + is_namedtuple_class, + is_namedtuple_instance, + is_structseq, + is_structseq_class, + is_structseq_instance, + PyTreeSpec, + register_pytree_node as _register_pytree_node, + tree_all, + tree_all_only, + tree_any, + tree_any_only, + tree_flatten, + tree_iter, + tree_leaves, + tree_map, + tree_map_, + tree_map_only, + tree_map_only_, + tree_structure, + tree_unflatten, + treespec_pprint, + ) + + +# Change `__module__` of reexported public APIs to 'torch.utils.pytree' +__func_names = frozenset( + { + "tree_all", + "tree_all_only", + "tree_any", + "tree_any_only", + "tree_flatten", + "tree_iter", + "tree_leaves", + "tree_map", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_structure", + "tree_unflatten", + "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", + } +) +globals().update( + { + name: _exposed_in(__name__)(member) + for name, member in globals().items() + if name in __func_names + } +) +del __func_names, _exposed_in + + +def register_pytree_node( + cls: type[_Any], + /, + # intentionally use `*_func` over `*_fn` to match annotations + flatten_func: FlattenFunc, + unflatten_func: 
UnflattenFunc, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls (type): A Python type to treat as an internal pytree node. + flatten_func (callable): A function to be used during flattening, taking an instance of + ``cls`` and returning a pair, with (1) an iterable for the children to be flattened + recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be + passed to the ``unflatten_func``. + unflatten_func (callable): A function taking two arguments: the unflattened children, and + the auxiliary data that was returned by ``flatten_func`` and stored in the treespec. + The function should return an instance of ``cls``. + + Example:: + + >>> # xdoctest: +SKIP + >>> from collections import UserList + ... class MyList(UserList): pass + >>> # Register a Python type with lambda functions + ... register_pytree_node( + ... MyList, + ... lambda lst: (list(lst), None), + ... lambda children, _: MyList(children), + ... ) + """ + _register_pytree_node( + cls, + flatten_func, + unflatten_func, + ) + + +def __getattr__(name: str) -> _Any: + if name == "cxx": + # Lazy import + import torch.utils._cxx_pytree as cxx # noqa: F811 + + _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx + return cxx + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/torch/utils/pytree/cxx.pyi b/torch/utils/pytree/cxx.pyi new file mode 100644 index 0000000000000..1c8a69c4bb429 --- /dev/null +++ b/torch/utils/pytree/cxx.pyi @@ -0,0 +1,8 @@ +# Owner(s): ["module: pytree"] + +from .._cxx_pytree import * # noqa: F403 +from .._cxx_pytree import ( + __all__ as __all__, + _broadcast_to_and_flatten as _broadcast_to_and_flatten, + KeyPath as KeyPath, +) diff --git a/torch/utils/pytree/python.pyi b/torch/utils/pytree/python.pyi new file mode 100644 index 0000000000000..61f72707cd055 --- /dev/null +++ b/torch/utils/pytree/python.pyi @@ -0,0 +1,15 @@ +# Owner(s): ["module: pytree"] + +from .._pytree import * # 
noqa: F403 +from .._pytree import ( + __all__ as __all__, + _broadcast_to_and_flatten as _broadcast_to_and_flatten, + arg_tree_leaves as arg_tree_leaves, + BUILTIN_TYPES as BUILTIN_TYPES, + GetAttrKey as GetAttrKey, + KeyEntry as KeyEntry, + KeyPath as KeyPath, + MappingKey as MappingKey, + SequenceKey as SequenceKey, + SUPPORTED_NODES as SUPPORTED_NODES, +)