auto functionalize base_hop · pytorch/pytorch@7bfa26d · GitHub

Commit 7bfa26d

auto functionalize base_hop
ghstack-source-id: 2358474
Pull Request resolved: #151067
1 parent 465d96c commit 7bfa26d
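The tests added below exercise what this commit enables: when a base HOP's subgraph (here `invoke_quant_test`) mutates its inputs, tracing now produces a pure subgraph wrapped in `torch.ops.higher_order.auto_functionalized_v2`, plus an epilogue that copies the updated values back into the original inputs (the `torch.ops.aten.copy_.default` calls visible in the expected graphs). A minimal hand-written sketch of that rewrite, using hypothetical helper names rather than the actual PyTorch machinery:

import torch

def mutating_subgraph(x, y):
    # the kind of subgraph the tests pass to invoke_quant_test: it mutates an input
    x.add_(1)
    return x + y

def functionalized_subgraph(x, y):
    # pure replacement: compute the would-be-mutated value out of place
    x_updated = x + 1
    return x_updated + y, x_updated

def call_with_epilogue(x, y):
    # run the pure subgraph, then copy the updated value back into the input,
    # mirroring the aten.copy_.default epilogue in the traced graphs below
    out, x_updated = functionalized_subgraph(x, y)
    x.copy_(x_updated)
    return out

x, y = torch.randn(3, 3), torch.randn(3, 3)
x_eager, x_func = x.clone(), x.clone()
assert torch.allclose(mutating_subgraph(x_eager, y), call_with_epilogue(x_func, y))
assert torch.allclose(x_eager, x_func)  # both paths leave the input in the same state

In the aot_eager graphs asserted below, the pure subgraph and the copy-back epilogue show up as `auto_functionalized_subgraph_0` and `aten.copy_.default` respectively.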

File tree

7 files changed: +498 -39 lines changed

test/dynamo/test_base_hop.py

Lines changed: 164 additions & 12 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import unittest
+import unittest.mock as mock

 import torch
 import torch._dynamo.test_case
@@ -11,6 +12,10 @@
     normalize_gm,
 )
 from torch._higher_order_ops.schema import find_hop_schema
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
 from torch.testing._internal.inductor_utils import HAS_CUDA


@@ -135,17 +140,47 @@ def inner(x, y):

         backend = EagerAndRecordGraphs()

-        @torch.compile(backend=backend, fullgraph=True)
         def f(x, y):
             return invoke_quant_test(inner, x, y, scheme="nf4")

-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
+        with mock.patch(
+            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+            True,
         ):
-            f(x.clone(), y)
+            torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_

-    def test_schema_gen_pytree_in_out_with_mutation(self):
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]; invoke_quant_test = None
+        return (getitem,)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            mul_: "f32[3, 3]" = l_y_.mul_(-1); mul_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            cos: "f32[3, 3]" = sin.cos(); sin = None
+            return (cos,)
+""", # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
+            """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor(a2!) arg1, *, str scheme="nf4") -> ((Tensor))""",
+        )
+
+    @parametrize("backend", ["eager", "aot_eager"])
+    def test_schema_gen_pytree_in_out_with_mutation(self, backend):
         def inner(x_y):
             x, y = x_y
             x.add_(1)
@@ -159,17 +194,88 @@ def inner(x_y):
         x = torch.randn(3, 3, requires_grad=False)
         y = torch.randn(3, 3, requires_grad=True)

-        backend = EagerAndRecordGraphs()
+        if backend == "eager":
+            bk = EagerAndRecordGraphs()
+        else:
+            assert backend == "aot_eager"
+            bk = AotEagerAndRecordGraphs()

-        @torch.compile(backend=backend, fullgraph=True)
         def f(x, y):
             return invoke_quant_test(inner, [x, y], scheme="nf4")

-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
+        with mock.patch(
+            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+            True,
         ):
-            f(x.clone(), y)
+            torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)
+
+        if backend == "eager":
+            self.assertEqual(len(bk.graphs), 1)
+            self.assertExpectedInline(
+                normalize_graph(bk.graphs[0]),
+                """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]
+        getitem_1: "f32[3, 3]" = invoke_quant_test[1]
+        getitem_2: "f32[3, 3]" = invoke_quant_test[2]
+        getitem_3: "f32[3, 3]" = invoke_quant_test[3]; invoke_quant_test = None
+        return (getitem, getitem_1, getitem_2, getitem_3)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            child: "f32[3, 3]" = sin.cos(); sin = None
+
+            child_1: "f32[3, 3]" = l_x_ + l_y_
+            child_2: "f32[3, 3]" = l_x_ - l_y_
+
+            child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            return (child, child_1, child_2, child_3)
+""", # noqa: B950
+            )
+            self.assertExpectedInline(
+                str(find_hop_schema(bk.graphs[0], invoke_quant_test)[0]),
+                """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor arg1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950
+            )
+        elif backend == "aot_eager":
+            self.assertEqual(len(bk.fw_graphs), 1)
+            self.assertExpectedInline(
+                normalize_graph(bk.fw_graphs[0]),
+                """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]
+        getitem_1: "f32[3, 3]" = auto_functionalized_v2[1]
+        getitem_2: "f32[3, 3]" = auto_functionalized_v2[2]
+        getitem_3: "f32[3, 3]" = auto_functionalized_v2[3]
+        getitem_4: "f32[3, 3]" = auto_functionalized_v2[4]; auto_functionalized_v2 = None
+        return (getitem, getitem_1, getitem_2, getitem_3, primals_1, primals_2, getitem_4)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
+            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1)
+            mm: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1)
+            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
+            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); sin = None
+            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1)
+            sub: "f32[3, 3]" = torch.ops.aten.sub.Tensor(add, arg1_1)
+            mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1); arg1_1 = None
+            copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            return (cos, add_1, sub, mm_1)
+""", # noqa: B950
+            )

     def test_none_input(self):
         def inner(x, y):
@@ -239,6 +345,49 @@ def forward(self, l_y_: "f32[3, 4]"):
 """,
         )

+    def test_auto_functionalize(self):
+        def inner(x, y):
+            x.add_(1)
+            return x + y
+
+        backend = AotEagerAndRecordGraphs()
+
+        def f(x, y):
+            return invoke_quant_test(inner, x, y, scheme="nf4")
+
+        x = torch.randn(3, 3, requires_grad=False)
+        x_clone = x.clone()
+        y = torch.randn(3, 3, requires_grad=True)
+        with mock.patch(
+            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+            True,
+        ):
+            compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
+        # assert x is not mutated
+        self.assertEqual(x, x_clone)
+        self.assertEqual(compiled_out, x + y + 1)
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.fw_graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]
+        getitem_1: "f32[3, 3]" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
+        return (getitem, primals_1, primals_2, getitem_1)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
+            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1)
+            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1); arg1_1 = None
+            copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            return (add_1,)
+""", # noqa: B950
+        )
+
     @torch._dynamo.config.patch(assume_static_by_default=True)
     def test_aot_eager(self):
         def inner(x, y):
@@ -353,6 +502,9 @@ def inner(x, y):
             invoke_quant_test(result, x, y, scheme="nf4")


+instantiate_parametrized_tests(BaseHOPTest)
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

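A side note on the schemas asserted above via `find_hop_schema`: the `(a1!)` / `(a2!)` annotations reuse ATen's alias syntax, where `!` marks an argument the operator writes to; that is how the generated HOP schema records which inputs the subgraph mutates. A small illustrative sketch, not part of this commit, reading the same write flag off a familiar in-place ATen op:

import torch

# aten.add_ is declared as "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
# so its first argument carries a write ("!") alias annotation, just like arg0 in the schemas above.
schema = torch.ops.aten.add_.Tensor._schema
for arg in schema.arguments:
    mutated = arg.alias_info is not None and arg.alias_info.is_write
    print(f"{arg.name}: mutates input = {mutated}")
# "self" is flagged as mutated; "other" and "alpha" are not.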

test/inductor/test_foreach.py

Lines changed: 43 additions & 0 deletions
@@ -2,6 +2,7 @@

 import sys
 import unittest
+import unittest.mock as mock

 import torch
 import torch._inductor
@@ -49,6 +50,11 @@ def add_op(x, y):
     return torch.add(x, y)


+def add_inplace_op(x, y):
+    x.add_(y)
+    return x.sin()
+
+
 def addrecip_op(x, y):
     return torch.reciprocal(torch.add(x, y))

@@ -77,6 +83,7 @@ def recipaddmul_op(x, y, z):

 # More general functions
 foreach_map_add_fn = foreach_map_wrapper(add_op)
+foreach_map_add_inplace = foreach_map_wrapper(add_inplace_op)
 foreach_map_recipaddmul = foreach_map_wrapper(addrecip_op)
 foreach_map_addcmul = foreach_map_wrapper(addcmul_op)
 foreach_map_recipaddmul = foreach_map_wrapper(recipaddmul_op)
@@ -1029,6 +1036,42 @@ def ref_fn(xs, ys):

         self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

+    @requires_cuda
+    def test_foreach_map_input_mutation(self):
+        def fn(xs, ys):
+            outs = foreach_map_add_inplace(xs, ys)
+            return outs[0].sum() + outs[1].sum() + outs[2].sum()
+
+        ref_inps = (
+            [
+                torch.rand(10, 20, device="cuda:0", requires_grad=True),
+                torch.rand(10, 30, device="cuda:0", requires_grad=True),
+                torch.rand(30, 30, device="cuda:0", requires_grad=True),
+            ],
+            [
+                torch.rand(10, 20, device="cuda:0", requires_grad=True),
+                torch.rand(10, 30, device="cuda:0", requires_grad=True),
+                torch.rand(30, 30, device="cuda:0", requires_grad=True),
+            ],
+        )
+        # Set requires_grad to be False to avoid mutating a leaf variable
+        inps = (
+            [x.clone().detach().requires_grad_(False) for x in ref_inps[0]],
+            [y.clone().detach().requires_grad_(False) for y in ref_inps[1]],
+        )
+
+        # TODO: after decomposing auto_functionalized, we're getting
+        # a functional subgraph with an inlined epilogue.
+        with self.assertRaisesRegex(
+            torch._inductor.exc.InductorError,
+            "Buffer mutation detected during lowering of aten.copy_.default",
+        ):
+            with mock.patch(
+                "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+                True,
+            ):
+                _ = run_fw_bw_and_get_code(lambda: torch.compile(fn)(*inps))
+
     @requires_cuda
     @foreach_map_un_ops
     def test_foreach_map_backward_unary(self, op):
