[HOP] Mutation and alias rework (#146658) · pytorch/pytorch@6803419

Commit 6803419
bohnstingl authored and pytorchmergebot committed
[HOP] Mutation and alias rework (#146658)
This PR reworks the way input mutations and various aliases are checked for higher-order operators (HOPs).

Pull Request resolved: #146658
Approved by: https://github.com/ydwu4
1 parent 0e805aa commit 6803419
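In practical terms, and as the test updates below show, the reworked checks require HOP subgraphs (for example the branches of cond and the body of map) to be pure: they may neither mutate their inputs in place nor return tensors that alias them. A minimal sketch of the two failure modes and the usual fix, assuming current torch.cond semantics; this is illustrative user-level code, not the PR's internal checking logic:

    import torch
    from torch import cond

    def bad_mutating_branch(x):
        x.add_(1)  # in-place input mutation: flagged by the checks
        return x

    def bad_aliasing_branch(x):
        return x[:2]  # view of the input: flagged as an alias

    def good_branch(x):
        return (x + 1)[:2].clone()  # out-of-place op plus clone: pure

    x = torch.randn(3, 2)
    out = cond(x.sum() > 0, good_branch, good_branch, [x])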

File tree

20 files changed: +516 -266 lines changed

functorch/experimental/control_flow.py

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 from torch import cond  # noqa: F401
-from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
 from torch._higher_order_ops.map import (  # noqa: F401
     _stack_pytree,
     _unstack_pytree,
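Only the re-export is dropped here, so any downstream code that imported the exception through functorch.experimental.control_flow needs updating. A hedged sketch of the migration, assuming the class is still defined in torch._higher_order_ops.cond after this rework:

    # Before this commit (re-export, now removed):
    # from functorch.experimental.control_flow import UnsupportedAliasMutationException

    # After it, import from the defining module (assumption: the class survives the rework):
    from torch._higher_order_ops.cond import UnsupportedAliasMutationException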

test/dynamo/test_export.py

Lines changed: 12 additions & 6 deletions

@@ -1873,7 +1873,7 @@ def true_fn(x):
                 return x + x

             def false_fn(x):
-                return x[:2]
+                return x[:2].clone()

             return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

@@ -1883,7 +1883,7 @@ def true_fn(x):
                 return x + x

             def false_fn(x):
-                return x[:2]
+                return x[:2].clone()

             return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))

@@ -1924,7 +1924,8 @@ def forward(self, l_x_):
 def forward(self, l_x_):
     l_x__1 = l_x_
     getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
-    return (getitem,)""",
+    clone = getitem.clone(); getitem = None
+    return (clone,)""",
         )
         # We could successfully export branches that return different sizes
         torch._dynamo.export(mod)(torch.randn(3, 2))
@@ -3302,7 +3303,12 @@ def f_branch_return_non_tensor(x):

     def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
         def f_branch_return_multiple_tensors(pred, x, y):
-            return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
+            return cond(
+                pred,
+                lambda x: (x.clone(), x.clone()),
+                lambda x: (x.clone(), x.clone()),
+                [y],
+            )

         example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
         gm, _ = torch._dynamo.export(
@@ -3324,10 +3330,10 @@ def forward(self, x):

     def test_cond_raise_user_error_on_mismatch_return_length(self):
         def true_fn(x):
-            return x
+            return x.clone()

         def false_fn(x):
-            return (x, x)
+            return (x.clone(), x.clone())

         def f_mismatch_return_length(x):
             return cond(torch.tensor(100), true_fn, false_fn, [x])
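The recurring fix in these tests is to clone what a branch returns: x[:2] is a view that shares storage with x, which the reworked alias check flags, while x[:2].clone() materializes fresh storage. A quick standalone illustration of that aliasing with plain eager tensors (no cond involved):

    import torch

    x = torch.randn(3, 2)
    view = x[:2]           # slicing returns a view backed by x's storage
    clone = x[:2].clone()  # clone() copies into fresh storage

    print(view._base is x)                   # True: the view aliases x
    print(view.data_ptr() == x.data_ptr())   # True: shared storage
    print(clone._base is None)               # True: no base tensor
    print(clone.data_ptr() == x.data_ptr())  # False: independent storage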

test/dynamo/test_higher_order_ops.py

Lines changed: 17 additions & 4 deletions

@@ -1791,7 +1791,13 @@ def forward(self, child : torch.Tensor):

     def test_map_pytree_return(self):
         def _construct_pytree(a):
-            return (a, [[[a]]], a, (a, (a,), a), {"a": a})
+            return (
+                a.clone(),
+                [[[a.clone()]]],
+                a.clone(),
+                (a.clone(), (a.clone(),), a.clone()),
+                {"a": a.clone()},
+            )

         def f(x):
             def inner_f(xs):
@@ -1823,7 +1829,14 @@ def forward(self, L_x_ : torch.Tensor):
             body_graph,
             """\
 def forward(self, child : torch.Tensor):
-    return (child, child, child, child, child, child, child)""",
+    child_1 = child.clone()
+    child_2 = child.clone()
+    child_3 = child.clone()
+    child_4 = child.clone()
+    child_5 = child.clone()
+    child_6 = child.clone()
+    child_7 = child.clone(); child = None
+    return (child_1, child_2, child_3, child_4, child_5, child_6, child_7)""",
         )

     def test_map_kwargs(self):
@@ -6902,7 +6915,7 @@ def test_cond_with_kwargs(self):

         def test(pred, x):
             def true_fn(x):
-                return x
+                return x.clone()

             def false_fn(x):
                 return -x
@@ -6926,7 +6939,7 @@ def test_cond_with_invalid_kwargs(self):

         def test(pred, mode, x):
             def true_fn(x):
-                return x
+                return x.clone()

             def false_fn(x):
                 return -x
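The same purity rule applies to map bodies: every leaf of the returned pytree must be a fresh tensor rather than the mapped input itself, which is why each leaf above gains a .clone(). A minimal sketch in the spirit of the test; the body, shapes, and pytree layout here are illustrative:

    import torch
    from functorch.experimental import control_flow

    def body(x):
        # Every returned leaf is a clone, so nothing aliases the mapped input.
        return (x.clone(), {"a": x.clone()})

    xs = torch.randn(3, 4)  # body is mapped over the leading dimension
    out = control_flow.map(body, xs)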

test/dynamo/test_misc.py

Lines changed: 1 addition & 1 deletion

@@ -5931,7 +5931,7 @@ def test_cond_export_single_arg(self):
         from functorch.experimental.control_flow import cond

         def true_fn(x):
-            return x
+            return x.clone()

         def false_fn(x):
             return x.sin()

test/export/test_export.py

Lines changed: 10 additions & 5 deletions

@@ -7604,20 +7604,25 @@ def forward(self, x):
         self.assertTrue(torch.allclose(ep.module()(xs), module_out))

     @requires_cuda
-    @testing.expectedFailureCppRuntime
     def test_export_associative_scan_lifted_buffers(self):
         device = torch.device("cuda")
         combine_mode = "pointwise"

+        class A(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.buffer = torch.nn.Buffer(torch.ones(3, 2, device=device))
+
+            def forward(self):
+                return self.buffer.cos()
+
         class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.register_buffer(
-                    "buf", torch.ones(3, 2, device=device), persistent=False
-                )
+                self.a = A()

             def combine_fn(self, x, y):
-                return x + y * self.buf
+                return (x + y) * self.a()

             def forward(self, x):
                 return associative_scan(
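Removing @testing.expectedFailureCppRuntime means the lifted-buffer case is now expected to pass end to end: the buffer moves into a submodule A whose forward reads it, so export has to lift that buffer into the scan's combine function. A hedged usage sketch of the pattern, mirroring the test above; the export call, input shape, dim argument, and CUDA guard are assumptions, since the diff hunk ends inside the associative_scan call:

    import torch
    from torch._higher_order_ops.associative_scan import associative_scan

    if torch.cuda.is_available():
        device = torch.device("cuda")

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.ones(3, 2, device=device))

            def forward(self):
                return self.buffer.cos()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = A()

            def combine_fn(self, x, y):
                # The buffer read inside self.a() is what export must lift.
                return (x + y) * self.a()

            def forward(self, x):
                # Scan over dim 0; each (3, 2) slice matches the buffer shape.
                return associative_scan(
                    self.combine_fn, x, dim=0, combine_mode="pointwise"
                )

        ep = torch.export.export(M(), (torch.randn(4, 3, 2, device=device),))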

test/functorch/test_aotdispatch.py

Lines changed: 6 additions & 6 deletions

@@ -4572,17 +4572,17 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):

         body_graph_0 = self.body_graph_0
         map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
-        getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
+        getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None

-        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
+        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None

         add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None

         body_graph_1 = self.body_graph_1
         map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
-        getitem_1: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
+        getitem_5: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None

-        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
+        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_5); getitem_5 = None

         add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
         return (add_1,)
@@ -4635,9 +4635,9 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):

         body_graph_0 = self.body_graph_0
         map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
-        getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
+        getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None

-        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
+        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
         add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
         return (add,)

0 commit comments