auto functionalize base_hop · pytorch/pytorch@419d795 · GitHub

Commit 419d795

auto functionalize base_hop

ghstack-source-id: 5b120c8
Pull Request resolved: #151067

1 parent d3c6f5d commit 419d795

File tree

11 files changed: +497 -111 lines changed
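
Background for the diff below: a BaseHOP subgraph that mutates one of its inputs used to be rejected during tracing; with this change the mutation is auto-functionalized instead, i.e. the HOP call is rewritten into torch.ops.higher_order.auto_functionalized_v2, which runs a functional copy of the subgraph and threads the would-be-mutated tensors through _all_bases. The schema produced by the new gen_schema hook (first hunk below) carries the mutation annotations (e.g. Tensor(a1!)) that the pass appears to rely on via _op_schema. A minimal hand-rolled sketch of the idea in plain Python, illustrative only and not the PyTorch internals:

import torch

def mutating_subgraph(x, y):
    x.add_(1)          # in-place mutation of an input
    return x + y

def auto_functionalized(subgraph, *operands, bases):
    # Run the subgraph against clones of the "bases" (tensors the subgraph
    # would mutate), so the caller's tensors are left untouched.
    clones = {id(b): b.clone() for b in bases}
    remapped = [clones.get(id(a), a) for a in operands]
    out = subgraph(*remapped)
    # Return the regular outputs plus the updated bases; a caller that wants
    # the mutation to be visible copies them back explicitly.
    return out, [clones[id(b)] for b in bases]

x, y = torch.randn(3, 3), torch.randn(3, 3)
out, (new_x,) = auto_functionalized(mutating_subgraph, x, y, bases=(x,))
# x itself is unchanged here; re-apply the mutation only if desired:
# x.copy_(new_x)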

test/dynamo/test_base_hop.py

Lines changed: 192 additions & 35 deletions
@@ -28,6 +28,66 @@ def __init__(self):
     def __call__(self, subgraph, *operands, scheme):
         return super().__call__(subgraph, *operands, scheme=scheme)
 
+    def gen_schema(self, subgraph, *operands, scheme):
+        # Idea 1: using inspect.signature and sample inputs to generate a schema
+        # Idea 2: we still need to know how to call into subgraph/fn given the inputs.
+        # wrap_subgraphs gives two callable to call into subgraph.
+        from torch._higher_order_ops.schema import (
+            CFunctionSchemaGen,
+            HopArgumentInfoGen,
+        )
+        from torch._higher_order_ops.utils import (
+            check_input_alias_and_mutation_return_ouputs,
+        )
+
+        (
+            mutated_inp_idx,
+            inp_inp_alias,
+            inp_out_alias,
+            out_out_alias,
+            output,
+        ) = check_input_alias_and_mutation_return_ouputs(subgraph, operands)
+        assert (
+            len(inp_inp_alias) == 0
+            and len(inp_out_alias) == 0
+            and len(out_out_alias) == 0
+        ), f"Aliasing is not suppported for HOP subgraph. {subgraph}"
+
+        args = [
+            HopArgumentInfoGen.from_example(
+                subgraph, name="subgraph", default_value=None, is_mutated=False
+            )
+        ]
+        for idx, arg in enumerate(operands):
+            example_value = arg
+            arg_name = f"operands{idx}"
+            args.append(
+                HopArgumentInfoGen.from_example(
+                    example_value=example_value,
+                    name=arg_name,
+                    default_value=None,
+                    is_mutated=idx in mutated_inp_idx,
+                )
+            )
+
+        args.append(
+            HopArgumentInfoGen.from_example(
+                example_value=scheme,
+                name="scheme",
+                default_value=scheme,
+                is_mutated=False,
+                kw_only=True,
+            )
+        )
+        output = HopArgumentInfoGen.from_example(
+            example_value=output,
+            name="output",
+            default_value=None,
+            is_mutated=False,
+            kw_only=False,
+        )
+        return CFunctionSchemaGen.from_hop_argument_info(str(self), args, output)
+
 
 invoke_quant_test = InvokeQuantTest()
 
@@ -93,7 +153,7 @@ def f(x, y):
         self.assertEqual(len(schemas), 1)
         self.assertExpectedInline(
             str(schemas[0]),
-            """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> ((Tensor))""",  # noqa: B950
+            """invoke_quant_test(Any subgraph, Tensor operands0, Tensor operands1, *, str scheme="nf4") -> ((Tensor))""",  # noqa: B950
         )
 
     def test_schema_gen_pytree_in_out(self):
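
A hedged sketch of exercising the gen_schema hook directly, assuming the InvokeQuantTest() instance defined in the first hunk is in scope and that a direct eager call works; the expected string is taken from the expecttest assertions in this file, so treat the output as illustrative:

import torch

def subgraph(x, y):
    return (x @ y).sin().cos()

x = torch.randn(3, 3)
y = torch.randn(3, 3)

# invoke_quant_test is the test-only InvokeQuantTest() instance from above.
schema = invoke_quant_test.gen_schema(subgraph, x, y, scheme="nf4")
print(schema)
# expected, per the assertions in this file:
# invoke_quant_test(Any subgraph, Tensor operands0, Tensor operands1, *, str scheme="nf4") -> ((Tensor))
# if the subgraph mutated operands0, that argument would instead carry a
# mutation annotation, e.g. Tensor(a1!) operands0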
@@ -121,7 +181,7 @@ def f(x, y):
         self.assertEqual(len(schemas), 1)
         self.assertExpectedInline(
             str(schemas[0]),
-            """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
+            """invoke_quant_test(Any subgraph, Tensor operands0, Tensor operands1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
         )
 
     def test_schema_gen_single_return_with_mutation(self):
@@ -135,15 +195,40 @@ def inner(x, y):
 
         backend = EagerAndRecordGraphs()
 
-        @torch.compile(backend=backend, fullgraph=True)
         def f(x, y):
             return invoke_quant_test(inner, x, y, scheme="nf4")
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
-        ):
-            f(x.clone(), y)
+        torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]; invoke_quant_test = None
+        return (getitem,)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            mul_: "f32[3, 3]" = l_y_.mul_(-1); mul_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            cos: "f32[3, 3]" = sin.cos(); sin = None
+            return (cos,)
+""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
+            """invoke_quant_test(Any subgraph, Tensor(a1!) operands0, Tensor(a2!) operands1, *, str scheme="nf4") -> ((Tensor))""",
+        )
 
     def test_schema_gen_pytree_in_out_with_mutation(self):
         def inner(x_y):
@@ -161,15 +246,46 @@ def inner(x_y):
 
         backend = EagerAndRecordGraphs()
 
-        @torch.compile(backend=backend, fullgraph=True)
         def f(x, y):
             return invoke_quant_test(inner, [x, y], scheme="nf4")
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
-        ):
-            f(x.clone(), y)
+        torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]
+        getitem_1: "f32[3, 3]" = invoke_quant_test[1]
+        getitem_2: "f32[3, 3]" = invoke_quant_test[2]
+        getitem_3: "f32[3, 3]" = invoke_quant_test[3]; invoke_quant_test = None
+        return (getitem, getitem_1, getitem_2, getitem_3)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            child: "f32[3, 3]" = sin.cos(); sin = None
+
+            child_1: "f32[3, 3]" = l_x_ + l_y_
+            child_2: "f32[3, 3]" = l_x_ - l_y_
+
+            child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            return (child, child_1, child_2, child_3)
+""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
+            """invoke_quant_test(Any subgraph, Tensor(a1!) operands0, Tensor operands1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
+        )
 
     def test_none_input(self):
         def inner(x, y):
@@ -239,6 +355,44 @@ def forward(self, l_y_: "f32[3, 4]"):
 """,
         )
 
+    def test_auto_functionalize(self):
+        def inner(x, y):
+            x.add_(1)
+            return x + y
+
+        backend = AotEagerAndRecordGraphs()
+
+        def f(x, y):
+            return invoke_quant_test(inner, x, y, scheme="nf4")
+
+        x = torch.randn(3, 3, requires_grad=False)
+        x_clone = x.clone()
+        y = torch.randn(3, 3, requires_grad=True)
+        compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
+        # assert x is not mutated
+        self.assertEqual(x, x_clone)
+        self.assertEqual(compiled_out, x + y + 1)
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.fw_graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
+        functiona_schema_0 = self.functiona_schema_0
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, operands1 = primals_2, scheme = 'nf4', _operands0_base_index = 0, _all_bases = [primals_1], _op_schema = functiona_schema_0); auto_functionalized_subgraph_0 = functiona_schema_0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
+        return (getitem, primals_1, primals_2)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
+            add_: "f32[3, 3]" = torch.ops.aten.add_.Tensor(arg0_1, 1); arg0_1 = None
+
+            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(add_, arg1_1); add_ = arg1_1 = None
+            return (add,)
+""",  # noqa: B950
+        )
+
     @torch._dynamo.config.patch(assume_static_by_default=True)
     def test_aot_eager(self):
         def inner(x, y):
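
An annotated restatement of the auto_functionalized_v2 call in the forward graph above; the per-keyword readings are inferred from this graph dump and the test's assertions, so treat them as a reading aid rather than a specification:

# torch.ops.higher_order.auto_functionalized_v2(
#     torch.ops.higher_order.invoke_quant_test,  # the HOP being functionalized
#     subgraph=auto_functionalized_subgraph_0,   # functional version of the inner graph
#     operands1=primals_2,                       # non-mutated operand, passed by name
#     scheme='nf4',                              # kw-only argument from the HOP schema
#     _operands0_base_index=0,                   # mutated operand -> entry 0 of _all_bases
#     _all_bases=[primals_1],                    # tensors backing mutated arguments
#     _op_schema=functiona_schema_0,             # schema produced by gen_schema
# )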
@@ -265,16 +419,17 @@ def f(x, y):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
-        subgraph0 = self.subgraph0
-        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph0, primals_1, primals_2, scheme = 'nf4'); subgraph0 = None
-        getitem: "f32[3, 3]" = invoke_quant_test[0]; invoke_quant_test = None
+        functiona_schema_0 = self.functiona_schema_0
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, operands0 = primals_1, operands1 = primals_2, scheme = 'nf4', _all_bases = [], _op_schema = functiona_schema_0); auto_functionalized_subgraph_0 = functiona_schema_0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
         return (getitem, primals_1, primals_2)
 
-    class subgraph0(torch.nn.Module):
+    class auto_functionalized_subgraph_0(torch.nn.Module):
         def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
             mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1); arg0_1 = arg1_1 = None
-            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
-            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); sin = None
+            sin_: "f32[3, 3]" = torch.ops.aten.sin_.default(mm); mm = None
+            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin_); sin_ = None
             return (cos,)
 """,  # NOQA: B950
         )
@@ -285,20 +440,21 @@ def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]", tangents_1: "f32[3, 3]"):
-        subgraph1 = self.subgraph1
-        invoke_quant_test_1 = torch.ops.higher_order.invoke_quant_test(subgraph1, primals_1, primals_2, tangents_1, scheme = 'nf4'); subgraph1 = primals_1 = primals_2 = tangents_1 = None
-        getitem_1: "f32[3, 3]" = invoke_quant_test_1[0]
-        getitem_2: "f32[3, 3]" = invoke_quant_test_1[1]; invoke_quant_test_1 = None
+        functiona_schema_1 = self.functiona_schema_1
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_1, operands0 = primals_1, operands1 = primals_2, operands2 = tangents_1, scheme = 'nf4', _all_bases = [], _op_schema = functiona_schema_1); auto_functionalized_subgraph_1 = primals_1 = primals_2 = tangents_1 = functiona_schema_1 = None
+        getitem_1: "f32[3, 3]" = auto_functionalized_v2_1[0]
+        getitem_2: "f32[3, 3]" = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None
         return (getitem_1, getitem_2)
 
-    class subgraph1(torch.nn.Module):
+    class auto_functionalized_subgraph_1(torch.nn.Module):
         def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]", arg2_1: "f32[3, 3]"):
             mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1)
             clone: "f32[3, 3]" = torch.ops.aten.clone.default(mm)
-            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
-            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); cos = None
-            sin_1: "f32[3, 3]" = torch.ops.aten.sin.default(sin); sin = None
-            neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
+            sin_: "f32[3, 3]" = torch.ops.aten.sin_.default(mm); mm = None
+            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin_); cos = None
+            sin: "f32[3, 3]" = torch.ops.aten.sin.default(sin_); sin_ = None
+            neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin); sin = None
             mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg2_1, neg); arg2_1 = neg = None
             cos_1: "f32[3, 3]" = torch.ops.aten.cos.default(clone); clone = None
             mul_1: "f32[3, 3]" = torch.ops.aten.mul.Tensor(mul, cos_1); mul = cos_1 = None
@@ -320,21 +476,22 @@ def inner2(x, y):
 
         x = torch.randn(3, 3)
         y = torch.randn(3, 3)
+        x_clone = x.clone()
+        y_clone = y.clone()
 
         @torch.compile(backend="eager", fullgraph=True)
         def f(inner, x, y):
             return invoke_quant_test(inner, x, y, scheme="nf4")
 
+        compiled_f = torch.compile(f, backend="eager", fullgraph=True)
+
         with self.assertRaisesRegex(
             RuntimeError, "Encountered aliasing during higher order op tracing for HOP"
         ):
-            f(inner, x, y)
+            compiled_f(inner, x, y)
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
-        ):
-            f(inner2, x, y)
+        compiled_out = compiled_f(inner2, x, y)
+        self.assertEqual(compiled_out, f(inner2, x_clone, y_clone))
 
     def test_eager_call(self):
         def inner(x, y):

test/higher_order_ops/test_invoke_quant.py

Lines changed: 13 additions & 13 deletions
@@ -13,7 +13,6 @@
 from torch._higher_order_ops import InvokeQuant
 from torch._inductor import config
 from torch._inductor.pattern_matcher import (
-    Arg,
     CallFunction,
     Ignored,
     Match,
@@ -119,9 +118,10 @@ def fn(x, y, z):
         logs = "\n".join(r.getMessage() for r in log.records)
         f = FileCheck()
         f.check("AFTER POST GRAD")
-        f.check("subgraph0").check("subgraph1")
-        for _ in range(2):
-            f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
+        f.check("subgraph0_1")
+        f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
+        f.check("subgraph0_0")
+        f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
         f.run(logs)
 
 
@@ -159,15 +159,15 @@ def fn_no_match(x, y, z):
 
         @register_graph_pattern(
             CallFunction(
-                torch.ops.aten.mm,
-                CallFunction(
-                    torch.ops.higher_order.invoke_quant,
-                    Ignored(),
-                    Ignored(),
-                    Ignored(),
-                    scheme="nf4",
-                ),
-                Arg(),
+                torch.ops.higher_order.auto_functionalized_v2,
+                Ignored(),
+                subgraph=Ignored(),
+                arg0=Ignored(),
+                arg1=Ignored(),
+                scheme="nf4",
+                quant_options=Ignored(),
+                _all_bases=Ignored(),
+                _op_schema=Ignored(),
             ),
             pass_dict=test_pass,
         )
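
Post-grad graphs now carry the functionalized wrapper rather than a bare invoke_quant call, so the test's pattern is re-registered against auto_functionalized_v2. A hedged sketch of the node shape the new pattern is presumably meant to match; keyword names mirror the registration above, while the exact post-grad call is an assumption:

# auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(
#     torch.ops.higher_order.invoke_quant,  # positional: the wrapped HOP
#     subgraph=...,                         # the invoke_quant subgraph module
#     arg0=..., arg1=...,                   # operands, passed as keywords
#     scheme='nf4',
#     quant_options=...,
#     _all_bases=[],                        # nothing mutated in this case
#     _op_schema=...,
# )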

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 1 addition & 1 deletion
@@ -3105,7 +3105,7 @@ def maybe_positional_arg_names(func):
 class BaseHOPVariable(WrapHigherOrderVariable):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.supports_input_mutation = False
+        self.supports_input_mutation = True
         self.supports_aliasing = False
 
     def python_type(self):
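
A hedged illustration of the user-visible effect of flipping supports_input_mutation, mirroring test_auto_functionalize above; it reuses the test-only invoke_quant_test HOP from test/dynamo/test_base_hop.py, and before this commit the compiled call raised the input-mutation error instead:

import torch
# invoke_quant_test is the test-only BaseHOP instance from test/dynamo/test_base_hop.py.

def inner(x, y):
    x.add_(1)              # input mutation inside the HOP subgraph
    return x + y

def f(x, y):
    return invoke_quant_test(inner, x, y, scheme="nf4")

x, y = torch.randn(3, 3), torch.randn(3, 3)
# Previously: RuntimeError("Encountered input mutation during higher order op
# tracing for HOP"). Now the mutation is auto-functionalized and this compiles.
out = torch.compile(f, backend="aot_eager", fullgraph=True)(x, y)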

0 commit comments