8000 [hop] Support more output types for `flat_apply` (#146714) · pytorch/pytorch@bab84f0 · GitHub
[go: up one dir, main page]

Skip to content

Commit bab84f0

Browse files
StrongerXi authored and pytorchmergebot committed
[hop] Support more output types for flat_apply (#146714)
This patch enables `flat_apply` to support certain non-Tensor output types like containers and graphable types. This will in turn enable the upcoming `mark_traceable` to support more output types. The patch also exposes a `func_to_graphable` rather than having the users calling the lower level `pytree.flatten(ConstantFunction(...))`. Pull Request resolved: #146714 Approved by: https://github.com/zou3519
1 parent 8594856 commit bab84f0

File tree

2 files changed

+54
-21
lines changed

2 files changed

+54
-21
lines changed

test/dynamo/test_flat_apply.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
import torch
55
import torch._dynamo.test_case
66
import torch.utils._pytree as pytree
7+
from torch._higher_order_ops.flat_apply import (
8+
flat_apply,
9+
func_to_graphable,
10+
is_graphable,
11+
to_graphable,
12+
)
713

814

915
def distance(a, b, norm):
@@ -40,14 +46,8 @@ def test_simple(self):
4046

4147
args = (a, b)
4248
kwargs = {"norm": norm}
43-
from torch._higher_order_ops.flat_apply import (
44-
ConstantFunction,
45-
flat_apply,
46-
is_graphable,
47-
to_graphable,
48-
)
49-
50-
empty_list, func_spec = pytree.tree_flatten(ConstantFunction(distance))
49+
50+
empty_list, func_spec = func_to_graphable(distance)
5151
self.assertEqual(empty_list, [])
5252

5353
flat_args, in_spec = to_graphable((args, kwargs))
@@ -59,6 +59,30 @@ def test_simple(self):
5959
result = flat_apply(func_spec, in_spec, *flat_args)
6060
self.assertEqual(result, distance(*args, **kwargs))
6161

62+
def test_non_tensor_output(self):
63+
tensor = torch.tensor
64+
65+
a = Point(tensor(0.0), tensor(0.0))
66+
b = Point(tensor(3.0), tensor(4.0))
67+
68+
args = (a, b)
69+
kwargs = {}
70+
71+
def f(a, b):
72+
return [a.x + 1, (b.x + 2, [a.y + 3, 4.0], "5"), 6 + b.y]
73+
74+
empty_list, func_spec = func_to_graphable(f)
75+
self.assertEqual(empty_list, [])
76+
77+
flat_args, in_spec = to_graphable((args, kwargs))
78+
79+
for arg in flat_args:
80+
self.assertTrue(is_graphable(arg))
81+
82+
# Test flat_apply returns same thing as original function
83+
result = flat_apply(func_spec, in_spec, *flat_args)
84+
self.assertEqual(result, f(*args, **kwargs))
85+
6286

6387
if __name__ == "__main__":
6488
from torch._dynamo.test_case import run_tests

torch/_higher_order_ops/flat_apply.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,23 @@ def from_graphable(flat_args, spec):
3434
return stuff
3535

3636

37+
def func_to_graphable(func):
38+
"""
39+
Pack and flatten a function type into graphable types.
40+
This is useful for legalizing the function argument of `flat_apply`.
41+
"""
42+
return pytree.tree_flatten(_ConstantFunction(func))
43+
44+
3745
@dataclass
38-
class ConstantFunction:
46+
class _ConstantFunction:
3947
func: Callable
4048

4149
def __call__(self, *args, **kwargs):
4250
return self.func(*args, **kwargs)
4351

4452

45-
pytree.register_constant(ConstantFunction)
53+
pytree.register_constant(_ConstantFunction)
4654

4755
_op_types = (
4856
torch._ops.OpOverload,
@@ -84,27 +92,28 @@ def __call__(self, func, in_spec, *flat_args, **_unused):
8492

8593
def impl(func, in_spec, *flat_args):
8694
if not isinstance(func, _op_types):
87-
# assume ConstantFunction
95+
# assume _ConstantFunction
8896
func = pytree._retrieve_constant(func)
89-
assert isinstance(func, ConstantFunction)
97+
assert isinstance(func, _ConstantFunction)
9098

9199
args, kwargs = from_graphable(flat_args, in_spec)
92100
out = func(*args, **kwargs)
93-
# Right now, all outputs must either be Tensor or lists/tuples of Tensors.
94-
# This matches the output type restriction on custom operators.
101+
102+
# Right now, all outputs must either be graphable or lists/tuples of graphables.
95103
#
96-
# TODO: The following can be updated to support non-Tensor outputs and pytrees.
97-
# For non-Tensor constant outputs: the assumption would be that they are constant
104+
# TODO: The following can be updated to support non-graphable outputs and pytrees.
105+
# For non-graphable constant outputs: the assumption would be that they are constant
98106
# (everytime the function runs those MUST be the same)
99107
# For pytree outputs:
100108
# I'm not sure if we need to return (flat_output, spec) or just (flat_output,):
101109
# in the latter case the tracers need to carry out the output specs
102110
# (they need to know how to reconstruct the object from just the flat_output).
103-
assert (
104-
isinstance(out, torch.Tensor)
105-
or isinstance(out, (tuple, list))
106-
and all(isinstance(x, torch.Tensor) for x in out)
107-
)
111+
def is_valid_output(x):
112+
if isinstance(x, (tuple, list)):
113+
return all(map(is_valid_output, x))
114+
return is_graphable(x)
115+
116+
assert is_valid_output(out)
108117
return out
109118

110119

0 commit comments

Comments (0)