[POC][functorch] use public PyTree API in `torch.func` · pytorch/pytorch@9144e6d · GitHub

Commit 9144e6d

[POC][functorch] use public PyTree API in torch.func
ghstack-source-id: 1c86397
Pull Request resolved: #137884
1 parent 16ec4ac commit 9144e6d

28 files changed (+154, -128 lines)

test/functorch/common_utils.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from functorch_additional_op_db import additional_op_db
 
 import torch
-import torch.utils._pytree as pytree
+import torch.utils.pytree as pytree
 from functorch import vmap
 from torch.testing._internal.autograd_function_db import autograd_function_db
 from torch.testing._internal.common_device_type import toleranceOverride
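For context (not part of the diff): this hunk is the whole migration pattern in miniature. A minimal sketch of what call sites see, assuming the public torch.utils.pytree module mirrors the private torch.utils._pytree API these files previously used:

# Before this commit (private, underscore-prefixed module):
#   import torch.utils._pytree as pytree
# After (public module under the same alias, so call sites stay unchanged):
import torch.utils.pytree as pytree

tree = {"a": [1, 2], "b": (3, 4)}
leaves, spec = pytree.tree_flatten(tree)          # leaves: [1, 2, 3, 4]
doubled = pytree.tree_map(lambda x: x * 2, tree)  # same structure, doubled leaves
assert pytree.tree_unflatten(leaves, spec) == tree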

test/functorch/test_aotdispatch.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
-import torch.utils._pytree as pytree
+import torch.utils.pytree as pytree
 from functorch import grad, jacrev, make_fx, vjp, vmap
 from functorch.compile import (
     aot_function,

test/functorch/test_control_flow.py

Lines changed: 1 addition & 5 deletions
@@ -4,7 +4,7 @@
 import unittest
 
 import torch
-import torch.utils._pytree as pytree
+import torch.utils.pytree as pytree
 from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
 from torch._higher_order_ops.associative_scan import associative_scan
@@ -1293,8 +1293,6 @@ def f(x, y):
         self.assertEqual(expected_grads, grads)
 
     def test_map_autograd_nested_list(self):
-        import torch.utils._pytree as pytree
-
         def f(x, y):
             a, b = x
             c, d = a
@@ -4304,8 +4302,6 @@ def g(xs, y):
         self.check_map_count(gm, 2)
 
     def test_tracing_map_autograd_symbolic_list(self):
-        import torch.utils._pytree as pytree
-
         def f(x, y):
             return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]
 
test/functorch/test_eager_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@
     TestCase,
     xfailIfTorchDynamo,
 )
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils.pytree import tree_flatten, tree_map, tree_unflatten
 
 
 USE_TORCHVISION = False

test/functorch/test_ops.py

Lines changed: 15 additions & 16 deletions
@@ -58,8 +58,7 @@
     unMarkDynamoStrictTest,
 )
 from torch.testing._internal.opinfo.core import SampleInput
-from torch.utils import _pytree as pytree
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils.pytree import tree_flatten, tree_leaves, tree_map, tree_unflatten
 
 
 aten = torch.ops.aten
@@ -161,7 +160,7 @@ def normalize_op_input_output3(
     f, args, kwargs, sample_args, output_process_fn_grad=None
 ):
     flat_args, args_spec = tree_flatten(args)
-    flat_sample_args = pytree.tree_leaves(sample_args)
+    flat_sample_args = tree_leaves(sample_args)
     diff_argnums = tuple(
         i
         for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
@@ -299,8 +298,8 @@ def wrapped(*args):
         if isinstance(primals_out, torch.Tensor):
             return (primals_out, tangents_out)
         else:
-            flat_primals_out = pytree.tree_leaves(primals_out)
-            flat_tangents_out = pytree.tree_leaves(tangents_out)
+            flat_primals_out = tree_leaves(primals_out)
+            flat_tangents_out = tree_leaves(tangents_out)
             return tuple(flat_primals_out + flat_tangents_out)
 
     return wrapped, tangents
@@ -334,8 +333,8 @@ def wrapped(*args):
         if isinstance(primals_out, torch.Tensor):
             return (primals_out, tangents_out)
         else:
-            flat_primals_out = pytree.tree_leaves(primals_out)
-            flat_tangents_out = pytree.tree_leaves(tangents_out)
+            flat_primals_out = tree_leaves(primals_out)
+            flat_tangents_out = tree_leaves(tangents_out)
             return tuple(flat_primals_out + flat_tangents_out)
 
     return wrapped, primals + tangents
@@ -1086,7 +1085,7 @@ def test_vmapvjpvjp(self, device, dtype, op):
         fn, args = get_vjpfull_variant(op, sample)
         result = fn(*args)
         cotangents = tree_map(lambda x: torch.randn_like(x), result)
-        cotangents = pytree.tree_leaves(cotangents)
+        cotangents = tree_leaves(cotangents)
         num_args = len(args)
 
         args_and_cotangents = tuple(args) + tuple(cotangents)
@@ -1096,8 +1095,8 @@ def vjp_of_vjp(*args_and_cotangents):
             cotangents = args_and_cotangents[num_args:]
             result, vjp_fn = vjp(fn, *args)
             result_vjps = vjp_fn(cotangents)
-            result = pytree.tree_leaves(result)
-            result_vjps = pytree.tree_leaves(result_vjps)
+            result = tree_leaves(result)
+            result_vjps = tree_leaves(result_vjps)
             return (*result, *result_vjps)
 
         is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
@@ -2103,8 +2102,8 @@ def jvp_of_vjp(*args):
             (primals, tangents) = tree_unflatten(args, spec)
             primals_out, tangents_out = jvp(push_vjp, primals, tangents)
 
-            flat_primals_out = pytree.tree_leaves(primals_out)
-            flat_tangents_out = pytree.tree_leaves(tangents_out)
+            flat_primals_out = tree_leaves(primals_out)
+            flat_tangents_out = tree_leaves(tangents_out)
             return tuple(flat_primals_out + flat_tangents_out)
 
         is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs)
@@ -2421,7 +2420,7 @@ def is_differentiable(inp):
         )
 
         def get_flat_differentiable(tree):
-            flattened = pytree.tree_leaves(tree)
+            flattened = tree_leaves(tree)
             return tuple(i for i in flattened if is_differentiable(i))
 
         def get_differentiable_linked(list1, list2):
@@ -2434,7 +2433,7 @@ def get_differentiable_linked(list1, list2):
             return zip(*paired_list)
 
         def filter_none(out):
-            flattened = pytree.tree_leaves(out)
+            flattened = tree_leaves(out)
             return tuple(o for o in flattened if o is not None)
 
         if not op.supports_autograd:
@@ -2452,8 +2451,8 @@ def compute_grad(cotangents):
             out_flattened = out
             cotangents_flattened = cotangents
             if not isinstance(out_flattened, torch.Tensor):
-                out_flattened = pytree.tree_leaves(out)
-                cotangents_flattened = pytree.tree_leaves(cotangents)
+                out_flattened = tree_leaves(out)
+                cotangents_flattened = tree_leaves(cotangents)
                 out_flattened, cotangents_flattened = get_differentiable_linked(
                     out_flattened, cotangents_flattened
                 )
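Beyond the rename, these hunks drop the module-qualified pytree.tree_leaves calls in favor of a directly imported tree_leaves. A hedged sketch of the helper the tests lean on, assuming the public module re-exports it as the new import line indicates: tree_leaves flattens arbitrarily nested transform outputs into a flat tensor list for comparison.

import torch
from torch.utils.pytree import tree_leaves  # assumed re-export from the public module

primals_out = (torch.ones(2), [torch.zeros(3), {"x": torch.ones(1)}])
flat = tree_leaves(primals_out)  # three tensors, nesting discarded
assert [t.shape for t in flat] == [torch.Size([2]), torch.Size([3]), torch.Size([1])]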

test/functorch/test_vmap.py

Lines changed: 2 additions & 2 deletions
@@ -38,6 +38,7 @@
 import functorch
 import torch
 import torch.nn.functional as F
+import torch.utils.pytree as pytree
 from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap
 from functorch.experimental import chunk_vmap
 from torch import Tensor
@@ -76,7 +77,6 @@
     xfailIfTorchDynamo,
 )
 from torch.testing._internal.custom_op_db import custom_op_db
-from torch.utils import _pytree as pytree
 
 
 def get_platform_specific_sdpa():
@@ -1340,7 +1340,7 @@ def _vmap_test(
     check_propagates_grad=True,
 ):
     result = vmap(op, in_dims, out_dims)(*inputs)
-    are_nested = [t.is_nested for t in pytree.tree_leaves(result)]
+    are_nested = [t.is_nested for t in pytree.tree_iter(result)]
     reference_result = reference_vmap(
         op, inputs, in_dims, out_dims, return_nt=any(are_nested)
     )
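One hunk here is more than a rename: the are_nested check switches from pytree.tree_leaves to pytree.tree_iter. A sketch of the difference, assuming tree_iter lazily yields leaves the way the private API's tree_iter does, so the comprehension consumes leaves without first materializing a list:

import torch
import torch.utils.pytree as pytree

result = (torch.randn(2, 3), [torch.randn(4)])

# tree_leaves builds the full list of leaves up front...
eager = [t.is_nested for t in pytree.tree_leaves(result)]
# ...while tree_iter yields them one at a time, skipping the intermediate list.
lazy = [t.is_nested for t in pytree.tree_iter(result)]
assert eager == lazy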

torch/_functorch/_aot_autograd/collect_metadata_analysis.py

Lines changed: 2 additions & 5 deletions
@@ -15,7 +15,7 @@
 from typing import Callable, DefaultDict, Dict, List, Optional, Set
 
 import torch
-import torch.utils._pytree as pytree
+import torch.utils.pytree as pytree
 from torch import Tensor
 from torch._guards import detect_fake_mode
 from torch._logging import getArtifactLogger
@@ -694,11 +694,8 @@ def view_avoid_dupes_with_primals(t):
             view_avoid_dupes_with_primals, traced_tangents
         )
 
-        output_tangents_start_idx = len(f_input_tangents)
-        output_tangents_end_idx = output_tangents_start_idx + len(f_output_tangents)
         tangents_and_memory_formats = [
-            coerce_tangent_and_suggest_memory_format(tt)
-            for i, tt in enumerate(traced_tangents)
+            coerce_tangent_and_suggest_memory_format(tt) for tt in traced_tangents
        ]
        traced_tangents = [t[0] for t in tangents_and_memory_formats]
        traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]

torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py

Lines changed: 1 addition & 1 deletion
@@ -8,8 +8,8 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
-import torch.utils._pytree as pytree
 import torch.utils.dlpack
+import torch.utils.pytree as pytree
 from torch import Tensor
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import lazy_format_graph_code

torch/_functorch/_aot_autograd/input_output_analysis.py

Lines changed: 7 additions & 4 deletions
@@ -11,10 +11,9 @@
 """
 
 import itertools
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
-import torch.utils._pytree as pytree
 from torch import Tensor
 from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.fx.experimental.symbolic_shapes import is_concrete_int
@@ -32,6 +31,10 @@
 from .utils import strict_zip
 
 
+if TYPE_CHECKING:
+    from torch.utils.pytree import PyTreeSpec
+
+
 zip = strict_zip
 
 
@@ -421,8 +424,8 @@ def _graph_output_names(gm):
 def create_graph_signature(
     fx_g: torch.fx.GraphModule,
     fw_metadata: ViewAndMutationMeta,
-    in_spec: pytree.TreeSpec,
-    out_spec: pytree.TreeSpec,
+    in_spec: "PyTreeSpec",
+    out_spec: "PyTreeSpec",
     *,
     user_args_flat: List[Tensor],
     params_and_buffers_flat: List[Tensor],
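Here PyTreeSpec is needed only in annotations, so the import moves under typing.TYPE_CHECKING and the annotations become strings; nothing new is imported at runtime. A generic sketch of the pattern (describe is a hypothetical function, not from the diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never executed at runtime.
    from torch.utils.pytree import PyTreeSpec


def describe(spec: "PyTreeSpec") -> str:
    # The quoted annotation keeps this valid even though PyTreeSpec
    # is undefined in this module at runtime.
    return repr(spec)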

torch/_functorch/_aot_autograd/schemas.py

Lines changed: 19 additions & 6 deletions
@@ -9,10 +9,19 @@
 import functools
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, Dict, List, NewType, Optional, Set, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NewType,
+    Optional,
+    Set,
+    TYPE_CHECKING,
+    Union,
+)
 
 import torch
-import torch.utils._pytree as pytree
 from torch._guards import Source
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor
@@ -27,6 +36,10 @@
 from .utils import strict_zip
 
 
+if TYPE_CHECKING:
+    from torch.utils.pytree import PyTreeSpec
+
+
 zip = strict_zip
 
 
@@ -691,8 +704,8 @@ class GraphSignature:
     buffers_to_mutate: Dict[GraphOutputName, FQN]
     user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName]
 
-    in_spec: pytree.TreeSpec
-    out_spec: pytree.TreeSpec
+    in_spec: "PyTreeSpec"
+    out_spec: "PyTreeSpec"
 
     backward_signature: Optional[BackwardSignature]
 
@@ -703,8 +716,8 @@ def from_tracing_metadata(
     def from_tracing_metadata(
         cls,
         *,
-        in_spec: pytree.TreeSpec,
-        out_spec: pytree.TreeSpec,
+        in_spec: "PyTreeSpec",
+        out_spec: "PyTreeSpec",
         graph_input_names: List[str],
         graph_output_names: List[str],
         view_mutation_metadata: ViewAndMutationMeta,
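For orientation, the in_spec/out_spec fields that GraphSignature now annotates as PyTreeSpec are the structure records produced by tree_flatten: the traced graph operates on the flat list, and the spec restores the user-facing nesting. A hedged sketch, assuming the public API keeps the (leaves, spec) return convention of the private one:

import torch
import torch.utils.pytree as pytree

user_args = ({"x": torch.ones(2)}, [torch.zeros(3)])
flat_args, in_spec = pytree.tree_flatten(user_args)  # flat tensors + PyTreeSpec
# A traced graph consumes/produces flat lists; the spec rebuilds the nesting:
rebuilt = pytree.tree_unflatten(flat_args, in_spec)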
