auto functionalize base_hop by ydwu4 · Pull Request #151067 · pytorch/pytorch · GitHub
auto functionalize base_hop #151067

Closed · wants to merge 36 commits

Commits (36)
1a076cf
autofunctionalize base_hop
ydwu4 Apr 10, 2025
6db4a45
Update on "autofunctionalize base_hop"
ydwu4 Apr 11, 2025
dc77c64
Update on "autofunctionalize base_hop"
ydwu4 Apr 11, 2025
d8baef2
Update on "auto functionalize base_hop"
ydwu4 Apr 12, 2025
ad3cdd9
Update on "auto functionalize base_hop"
ydwu4 Apr 12, 2025
e044891
Update on "auto functionalize base_hop"
ydwu4 Apr 14, 2025
3514e31
Update on "auto functionalize base_hop"
ydwu4 Apr 14, 2025
2c27c6e
Update on "auto functionalize base_hop"
ydwu4 Apr 16, 2025
60c7747
Update on "auto functionalize base_hop"
ydwu4 Apr 17, 2025
d41a2f3
Update on "auto functionalize base_hop"
ydwu4 Apr 22, 2025
603431b
Update on "auto functionalize base_hop"
ydwu4 Apr 24, 2025
077586d
Update on "auto functionalize base_hop"
ydwu4 Apr 24, 2025
cb7a3c1
Update on "auto functionalize base_hop"
ydwu4 Apr 24, 2025
fd7331e
Update on "auto functionalize base_hop"
ydwu4 Apr 24, 2025
0664e7e
Update on "auto functionalize base_hop"
ydwu4 Apr 24, 2025
6800663
Update on "auto functionalize base_hop"
ydwu4 Apr 24, 2025
0f92f1e
Update on "auto functionalize base_hop"
ydwu4 Apr 24, 2025
4d13e71
Update on "auto functionalize base_hop"
ydwu4 Apr 25, 2025
06ea99c
Update on "auto functionalize base_hop"
ydwu4 Apr 25, 2025
13d7470
Update on "auto functionalize base_hop"
ydwu4 Apr 25, 2025
6a2f28c
Update on "auto functionalize base_hop"
ydwu4 Apr 25, 2025
951b961
Update on "auto functionalize base_hop"
ydwu4 Apr 25, 2025
2af2d3e
Update on "auto functionalize base_hop"
ydwu4 Apr 26, 2025
a4121fc
Update on "auto functionalize base_hop"
ydwu4 Apr 27, 2025
6461a1e
Update on "auto functionalize base_hop"
ydwu4 Apr 28, 2025
390bcf2
Update on "auto functionalize base_hop"
ydwu4 Apr 28, 2025
1b8741e
Update on "auto functionalize base_hop"
ydwu4 Apr 28, 2025
d5db6fa
Update on "auto functionalize base_hop"
ydwu4 Apr 30, 2025
a8bc21d
Update on "auto functionalize base_hop"
ydwu4 May 1, 2025
0b0d8ae
Update on "auto functionalize base_hop"
ydwu4 May 1, 2025
92f00ac
Update on "auto functionalize base_hop"
ydwu4 May 1, 2025
e8e0653
Update on "auto functionalize base_hop"
ydwu4 May 6, 2025
473cd91
Update on "auto functionalize base_hop"
ydwu4 May 6, 2025
e779246
Update on "auto functionalize base_hop"
ydwu4 May 6, 2025
0b7695b
Update on "auto functionalize base_hop"
ydwu4 May 6, 2025
e2ecbb5
Update on "auto functionalize base_hop"
ydwu4 May 21, 2025
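The new tests below exercise the behavior this PR enables: a base HOP whose subgraph mutates one of its inputs can now be traced once BaseHOPVariable.supports_input_mutation is turned on (the tests flip it with mock.patch), instead of raising "Encountered input mutation during higher order op tracing for HOP". A minimal sketch of that pattern follows; the import paths for invoke_quant_test and the recording backend are assumptions based on the test's imports, and the rest mirrors the test code.

import unittest.mock as mock

import torch
from torch._dynamo.testing import EagerAndRecordGraphs  # assumed import location
from torch._higher_order_ops.schema import find_hop_schema

# Assumed import location for the test-only HOP used throughout the new tests.
from torch._higher_order_ops.invoke_quant import invoke_quant_test


def inner(x, y):
    x.add_(1)  # in-place mutation of a subgraph input
    return (x @ y).sin().cos()


def f(x, y):
    return invoke_quant_test(inner, x, y, scheme="nf4")


backend = EagerAndRecordGraphs()
x, y = torch.randn(3, 3), torch.randn(3, 3)

# With the flag patched on, tracing succeeds; previously this path raised
# "Encountered input mutation during higher order op tracing for HOP".
with mock.patch(
    "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
    True,
):
    torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)

# The auto-generated HOP schema marks mutated tensor arguments with `!`,
# e.g. `Tensor(a1!) arg0` in the expected schema strings in the tests below.
print(find_hop_schema(backend.graphs[0], invoke_quant_test)[0])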
176 changes: 164 additions & 12 deletions test/dynamo/test_base_hop.py
@@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import unittest
import unittest.mock as mock

import torch
import torch._dynamo.test_case
@@ -11,6 +12,10 @@
normalize_gm,
)
from torch._higher_order_ops.schema import find_hop_schema
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


@@ -135,17 +140,47 @@ def inner(x, y):

backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def f(x, y):
return invoke_quant_test(inner, x, y, scheme="nf4")

with self.assertRaisesRegex(
RuntimeError,
"Encountered input mutation during higher order op tracing for HOP",
with mock.patch(
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
True,
):
f(x.clone(), y)
torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
self.assertEqual(len(backend.graphs), 1)
self.assertExpectedInline(
normalize_graph(backend.graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
l_x_ = L_x_
l_y_ = L_y_

def test_schema_gen_pytree_in_out_with_mutation(self):
subgraph_0 = self.subgraph_0
invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
getitem: "f32[3, 3]" = invoke_quant_test[0]; invoke_quant_test = None
return (getitem,)

class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
add_: "f32[3, 3]" = l_x_.add_(1); add_ = None

mul_: "f32[3, 3]" = l_y_.mul_(-1); mul_ = None

matmul: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
sin: "f32[3, 3]" = matmul.sin(); matmul = None
cos: "f32[3, 3]" = sin.cos(); sin = None
return (cos,)
""", # noqa: B950
)
self.assertExpectedInline(
str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
"""invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor(a2!) arg1, *, str scheme="nf4") -> ((Tensor))""",
)

@parametrize("backend", ["eager", "aot_eager"])
def test_schema_gen_pytree_in_out_with_mutation(self, backend):
def inner(x_y):
x, y = x_y
x.add_(1)
@@ -159,17 +194,88 @@ def inner(x_y):
x = torch.randn(3, 3, requires_grad=False)
y = torch.randn(3, 3, requires_grad=True)

backend = EagerAndRecordGraphs()
if backend == "eager":
bk = EagerAndRecordGraphs()
else:
assert backend == "aot_eager"
bk = AotEagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def f(x, y):
return invoke_quant_test(inner, [x, y], scheme="nf4")

with self.assertRaisesRegex(
RuntimeError,
"Encountered input mutation during higher order op tracing for HOP",
with mock.patch(
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
True,
):
f(x.clone(), y)
torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)

if backend == "eager":
self.assertEqual(len(bk.graphs), 1)
self.assertExpectedInline(
normalize_graph(bk.graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
l_x_ = L_x_
l_y_ = L_y_

subgraph_0 = self.subgraph_0
invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
getitem: "f32[3, 3]" = invoke_quant_test[0]
getitem_1: "f32[3, 3]" = invoke_quant_test[1]
getitem_2: "f32[3, 3]" = invoke_quant_test[2]
getitem_3: "f32[3, 3]" = invoke_quant_test[3]; invoke_quant_test = None
return (getitem, getitem_1, getitem_2, getitem_3)

class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
add_: "f32[3, 3]" = l_x_.add_(1); add_ = None

matmul: "f32[3, 3]" = l_x_ @ l_y_
sin: "f32[3, 3]" = matmul.sin(); matmul = None
child: "f32[3, 3]" = sin.cos(); sin = None

child_1: "f32[3, 3]" = l_x_ + l_y_
child_2: "f32[3, 3]" = l_x_ - l_y_

child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (child, child_1, child_2, child_3)
""", # noqa: B950
)
self.assertExpectedInline(
str(find_hop_schema(bk.graphs[0], invoke_quant_test)[0]),
"""invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor arg1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950
)
elif backend == "aot_eager":
self.assertEqual(len(bk.fw_graphs), 1)
self.assertExpectedInline(
normalize_graph(bk.fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
_tree_spec_constant0 = self._tree_spec_constant0
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None
getitem: "f32[3, 3]" = auto_functionalized_v2[0]
getitem_1: "f32[3, 3]" = auto_functionalized_v2[1]
getitem_2: "f32[3, 3]" = auto_functionalized_v2[2]
getitem_3: "f32[3, 3]" = auto_functionalized_v2[3]
getitem_4: "f32[3, 3]" = auto_functionalized_v2[4]; auto_functionalized_v2 = None
return (getitem, getitem_1, getitem_2, getitem_3, primals_1, primals_2, getitem_4)

class auto_functionalized_subgraph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1)
sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); sin = None
add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1)
sub: "f32[3, 3]" = torch.ops.aten.sub.Tensor(add, arg1_1)
mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1); arg1_1 = None
copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (cos, add_1, sub, mm_1)
""", # noqa: B950
)

def test_none_input(self):
def inner(x, y):
@@ -239,6 +345,49 @@ def forward(self, l_y_: "f32[3, 4]"):
""",
)

def test_auto_functionalize(self):
def inner(x, y):
x.add_(1)
return x + y

backend = AotEagerAndRecordGraphs()

def f(x, y):
return invoke_quant_test(inner, x, y, scheme="nf4")

x = torch.randn(3, 3, requires_grad=False)
x_clone = x.clone()
y = torch.randn(3, 3, requires_grad=True)
with mock.patch(
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
True,
):
compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
# assert x is not mutated
self.assertEqual(x, x_clone)
self.assertEqual(compiled_out, x + y + 1)
self.assertEqual(len(backend.fw_graphs), 1)
self.assertExpectedInline(
normalize_graph(backend.fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
_tree_spec_constant0 = self._tree_spec_constant0
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None
getitem: "f32[3, 3]" = auto_functionalized_v2[0]
getitem_1: "f32[3, 3]" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
return (getitem, primals_1, primals_2, getitem_1)

class auto_functionalized_subgraph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1)
add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1); arg1_1 = None
copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_1,)
""", # noqa: B950
)

@torch._dynamo.config.patch(assume_static_by_default=True)
def test_aot_eager(self):
def inner(x, y):
@@ -352,6 +501,9 @@ def inner(x, y):
invoke_quant_test(result, x, y, scheme="nf4")


instantiate_parametrized_tests(BaseHOPTest)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

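On the aot_eager path, test_auto_functionalize above checks that the mutating HOP is rewritten into torch.ops.higher_order.auto_functionalized_v2 over a functionalized copy of the subgraph (with the mutation replayed by a copy_ inside that subgraph), and that the input seen by the caller is left unmutated. A short sketch of that flow; as before, the import paths for invoke_quant_test and the recording backend are assumptions, and the assertions restate what the test checks.

import unittest.mock as mock

import torch
from torch._dynamo.testing import AotEagerAndRecordGraphs  # assumed import location
from torch._higher_order_ops.invoke_quant import invoke_quant_test  # assumed import location


def inner(x, y):
    x.add_(1)
    return x + y


def f(x, y):
    return invoke_quant_test(inner, x, y, scheme="nf4")


backend = AotEagerAndRecordGraphs()
x = torch.randn(3, 3)
x_clone = x.clone()
y = torch.randn(3, 3, requires_grad=True)

with mock.patch(
    "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
    True,
):
    out = torch.compile(f, backend=backend, fullgraph=True)(x, y)

assert torch.equal(x, x_clone)  # per the test, x is not mutated by the compiled call
assert torch.allclose(out, x + y + 1)

# The recorded forward graph contains the auto_functionalized_v2 call shown in
# the expected output above.
backend.fw_graphs[0].print_readable()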
43 changes: 43 additions & 0 deletions test/inductor/test_foreach.py
@@ -2,6 +2,7 @@

import sys
import unittest
import unittest.mock as mock

import torch
import torch._inductor
@@ -49,6 +50,11 @@ def add_op(x, y):
return torch.add(x, y)


def add_inplace_op(x, y):
x.add_(y)
return x.sin()


def addrecip_op(x, y):
return torch.reciprocal(torch.add(x, y))

@@ -77,6 +83,7 @@ def recipaddmul_op(x, y, z):

# More general functions
foreach_map_add_fn = foreach_map_wrapper(add_op)
foreach_map_add_inplace = foreach_map_wrapper(add_inplace_op)
foreach_map_recipaddmul = foreach_map_wrapper(addrecip_op)
foreach_map_addcmul = foreach_map_wrapper(addcmul_op)
foreach_map_recipaddmul = foreach_map_wrapper(recipaddmul_op)
@@ -1029,6 +1036,42 @@ def ref_fn(xs, ys):

self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

@requires_cuda
def test_foreach_map_input_mutation(self):
def fn(xs, ys):
outs = foreach_map_add_inplace(xs, ys)
return outs[0].sum() + outs[1].sum() + outs[2].sum()

ref_inps = (
[
torch.rand(10, 20, device="cuda:0", requires_grad=True),
torch.rand(10, 30, device="cuda:0", requires_grad=True),
torch.rand(30, 30, device="cuda:0", requires_grad=True),
],
[
torch.rand(10, 20, device="cuda:0", requires_grad=True),
torch.rand(10, 30, device="cuda:0", requires_grad=True),
torch.rand(30, 30, device="cuda:0", requires_grad=True),
],
)
# Set requires_grad to be False to avoid mutating a leaf variable
inps = (
[x.clone().detach().requires_grad_(False) for x in ref_inps[0]],
[y.clone().detach().requires_grad_(False) for y in ref_inps[1]],
)

# TODO: after decomposing auto_functionalized, we're getting
# a functional subgraph with an inlined epilogue.
with self.assertRaisesRegex(
torch._inductor.exc.InductorError,
"Buffer mutation detected during lowering of aten.copy_.default",
):
with mock.patch(
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
True,
):
_ = run_fw_bw_and_get_code(lambda: torch.compile(fn)(*inps))

@requires_cuda
@foreach_map_un_ops
def test_foreach_map_backward_unary(self, op):
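The new foreach test drives the same supports_input_mutation flag through foreach_map with an in-place add. Per the TODO in the test, Dynamo traces it, but Inductor lowering still rejects the resulting copy_ with "Buffer mutation detected during lowering of aten.copy_.default". A hedged sketch of that pattern (the foreach_map import path and call convention are assumptions; the test itself goes through its foreach_map_wrapper helper and runs under requires_cuda):

import unittest.mock as mock

import torch
import torch._inductor.exc

# Assumed import path and call convention for the foreach_map HOP.
from torch._higher_order_ops.foreach_map import foreach_map


def add_inplace_op(x, y):
    x.add_(y)  # mutate the first operand in place
    return x.sin()


def fn(xs, ys):
    outs = foreach_map(add_inplace_op, xs, ys)  # maps the op over the two lists
    return outs[0].sum() + outs[1].sum()


# requires_grad stays False so the in-place add does not mutate a leaf that requires grad.
xs = [torch.rand(10, 20, device="cuda"), torch.rand(10, 30, device="cuda")]
ys = [torch.rand(10, 20, device="cuda"), torch.rand(10, 30, device="cuda")]

with mock.patch(
    "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
    True,
):
    try:
        torch.compile(fn)(xs, ys)
    except torch._inductor.exc.InductorError as e:
        # Expected today, per the test above:
        # "Buffer mutation detected during lowering of aten.copy_.default"
        print(e)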