[HOP] Mutation and alias rework · Pull Request #146658 · pytorch/pytorch · GitHub

Closed · bohnstingl wants to merge 66 commits into main from mutation_alias…
Changes from 1 commit. All 66 commits:
9e8651f  WIP: rework (bohnstingl, Feb 7, 2025)
35e91b9  Lintrunner (bohnstingl, Feb 7, 2025)
0f7f261  Fixed import issue with FlexAttention (bohnstingl, Feb 7, 2025)
def6c92  Code cleanup (bohnstingl, Feb 7, 2025)
08611ca  Fixed missing import issues (bohnstingl, Feb 7, 2025)
83660b9  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Feb 12, 2025)
a3087fe  Integrated lifted arguments and aot_eager backend for associative_scan (bohnstingl, Feb 12, 2025)
0bdf2cd  WIP commit (bohnstingl, Feb 13, 2025)
0cfe187  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Feb 21, 2025)
c0e507d  WIP (bohnstingl, Mar 1, 2025)
9dbd727  Moved alias and mutation checks to dynamo for associative_scan and scan (bohnstingl, Mar 1, 2025)
884a18a  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Mar 1, 2025)
b171692  Integrated mutation check to cond (bohnstingl, Mar 1, 2025)
3a977f9  Reworked alias and mutation checks for scan, associative_scan, while_… (bohnstingl, Mar 2, 2025)
ca6eedb  WIP (bohnstingl, Mar 4, 2025)
ef2e493  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Mar 5, 2025)
8eea0f2  Integrated code reviews (bohnstingl, Mar 5, 2025)
ef1a56e  Fixed lint issues (bohnstingl, Mar 5, 2025)
0920f06  Fixed lint issues (bohnstingl, Mar 5, 2025)
61b71d7  WIP (bohnstingl, Mar 6, 2025)
eb99f6d  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Mar 12, 2025)
3947d5a  Deferred flex_attention mutation checking (bohnstingl, Mar 12, 2025)
6bcbc69  Fixed import issue with flex_attention (bohnstingl, Mar 12, 2025)
05e6e7d  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Mar 13, 2025)
d410185  Only use input mutation for flex_attention (bohnstingl, Mar 13, 2025)
4e67997  Fixed variable name error (bohnstingl, Mar 13, 2025)
9e609eb  Fixed imports (bohnstingl, Mar 13, 2025)
6450042  Fixed issues with mutation and alias check, i.e., pytree output neede… (bohnstingl, Mar 13, 2025)
23a2ed9  Fixed alias issue with cond (bohnstingl, Mar 13, 2025)
a23a12c  Fixed issues with CI tests for input mutation/alias. Reverted checks … (bohnstingl, Mar 14, 2025)
ed9e979  Unified analysis of mutation and alias (bohnstingl, Mar 14, 2025)
0a222ad  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Mar 14, 2025)
ac4a048  Fixed testcases (bohnstingl, Mar 14, 2025)
95c49ff  Lintrunner fixes (bohnstingl, Mar 14, 2025)
06173fb  Fixed some more testcases and cleaned resolved issue references (bohnstingl, Mar 17, 2025)
3cf0680  Removed unnecessary code (bohnstingl, Mar 17, 2025)
34d8676  Update to executorch_call_delegate (bohnstingl, Mar 17, 2025)
a8c5b11  Fixed issue with lifted arguments in mutation checks for scan and ass… (bohnstingl, Mar 17, 2025)
c05af7d  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Mar 17, 2025)
12d9922  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Apr 24, 2025)
1750287  Updates to HOPs and to alias checking speculate_subgraph (bohnstingl, Apr 27, 2025)
9adc558  Fixed issues with new alias rework and fake inputs (bohnstingl, Apr 29, 2025)
e3e8200  Fixing lint issues and CI tests (bohnstingl, Apr 29, 2025)
e6831d9  Fixed further CI testcases and lint issues (bohnstingl, Apr 30, 2025)
c8131ec  Fixed cond testcase (bohnstingl, Apr 30, 2025)
781291f  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, Apr 30, 2025)
48483a1  Reverted executorch HOP (bohnstingl, Apr 30, 2025)
c92f378  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, May 2, 2025)
738b649  Fixed CI tests (bohnstingl, May 2, 2025)
b9d261c  Fixed CI tests (bohnstingl, May 2, 2025)
5802346  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, May 5, 2025)
15fe278  Removed unnecessary code (bohnstingl, May 5, 2025)
ee7fc82  Fixed merge issues (bohnstingl, May 6, 2025)
784969c  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, May 6, 2025)
e6a763e  Fixed CI tests (bohnstingl, May 6, 2025)
dfe3fe1  Update invoke_subgraph graphs (bohnstingl, May 6, 2025)
770b761  Reverted invoke_subgraph and FunctionalizeCtxWrapper (bohnstingl, May 7, 2025)
7cd9e70  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, May 12, 2025)
6ef62ea  Fixed issue with return order of alias and mutation function (bohnstingl, May 12, 2025)
089513e  Fixed cond testcases (bohnstingl, May 12, 2025)
58acb70  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, May 13, 2025)
34367a8  Cleaned up code (bohnstingl, May 13, 2025)
474e86d  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, May 15, 2025)
fe3b370  Fixed some CI issues and integrated review comments (bohnstingl, May 15, 2025)
e1d46b4  Merge branch 'main' of github.com:pytorch/pytorch into mutation_alias… (bohnstingl, May 16, 2025)
55b0301  Fixes to CI tests because of alias_mutation order (bohnstingl, May 16, 2025)
Currently viewing commit 9e8651f4087d98d9a361ed85f7a22dfcb700c91e (WIP: rework), committed by bohnstingl on Feb 7, 2025.
32 changes: 16 additions & 16 deletions test/functorch/test_control_flow.py
@@ -4710,7 +4710,7 @@ def forward(self, x_1):
         # torch.cond triggers the check of the branches because the predicate
         # is a SymBool.
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "One of torch.cond branch"
+            UnsupportedAliasMutationException, "A branch or combine_fn might be modifying the input!.*"
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
@@ -4751,7 +4751,7 @@ def forward(self, x_1):
         # torch.cond triggers the check of the branches because the predicate
         # is a SymBool.
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "One of torch.cond branch"
+            UnsupportedAliasMutationException, "A branch or combine_fn might be modifying the input!.*"
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
@@ -4785,7 +4785,7 @@ def forward(self, x_1):
         # torch.cond triggers the check of the branches because the predicate
         # is a SymBool.
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "One of torch.cond branch"
+            UnsupportedAliasMutationException, "Aliasing within branch or combine_fn might be occuring!.*"
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
@@ -4813,7 +4813,7 @@ def f(x):
 
         example_inputs = (torch.ones(4, 5),)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "One of torch.cond branch"
+            UnsupportedAliasMutationException, "A branch or combine_fn might be modifying the input!.*"
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
@@ -4846,7 +4846,7 @@ def f(x):
             f(example_input_func)
 
             with self.assertRaisesRegex(
-                UnsupportedAliasMutationException, "One of torch.cond branch"
+                UnsupportedAliasMutationException, "A branch or combine_fn might be modifying the input!.*"
             ):
                 make_fx(f, tracing_mode="symbolic")(example_input_func)
         finally:
@@ -4864,7 +4864,7 @@ def wrapper(*args, **kwargs):
             return wrapper
 
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "One of torch.cond branch"
+            UnsupportedAliasMutationException, "A branch or combine_fn might be modifying the input!.*"
         ):
             make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
 
@@ -4888,7 +4888,7 @@ def f(x):
             torch._enable_functionalization(reapply_views=False)
             with self.assertRaisesRegex(
                 UnsupportedAliasMutationException,
-                "One of torch.cond branch might be aliasing",
+                "Aliasing within branch or combine_fn might be occuring!.*",
             ):
                 f(example_input_func)
         finally:
@@ -4919,7 +4919,7 @@ def wrapper(*args, **kwargs):
 
         with self.assertRaisesRegex(
             UnsupportedAliasMutationException,
-            "One of torch.cond branch might be aliasing",
+            "Aliasing within branch or combine_fn might be occuring!.*",
         ):
             make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
 
@@ -5513,7 +5513,7 @@ def f(xs, y):
         example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "torch.map is mutating the input!"
+            UnsupportedAliasMutationException, "A branch or combine_fn might be modifying the input!.*"
         ):
             functional_f(*example_inputs)
 
@@ -5530,7 +5530,7 @@ def f(xs, y):
         example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "torch.map is mutating the input!"
+            UnsupportedAliasMutationException, "A branch or combine_fn might be modifying the input!.*"
         ):
             functional_f(*example_inputs)
 
@@ -5570,7 +5570,7 @@ def f(xs):
         example_inputs = (torch.ones(3, 2, 4),)
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "torch.map is aliasing the input!"
+            UnsupportedAliasMutationException, "Aliasing within branch or combine_fn might be occuring!.*"
         ):
             functional_f(*example_inputs)
 
@@ -6462,7 +6462,7 @@ def f(init, xs):
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
             UnsupportedAliasMutationException,
-            "Combine_fn might be modifying the input!",
+            "A branch or combine_fn might be modifying the input!.*",
         ):
             functional_f(example_init, example_inputs)
 
@@ -6476,7 +6476,7 @@ def f(init, xs):
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
             UnsupportedAliasMutationException,
-            "Combine_fn might be modifying the input!",
+            "A branch or combine_fn might be modifying the input!.*",
         ):
             functional_f(example_init, example_inputs)
 
@@ -6493,7 +6493,7 @@ def f(init, xs):
         example_init = torch.ones(5, 4)
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "Combine_fn might be aliasing the input!"
+            UnsupportedAliasMutationException, "Aliasing within branch or combine_fn might be occuring!.*"
         ):
             functional_f(example_init, example_inputs)
 
@@ -7015,7 +7015,7 @@ def fn(f, *args):
         x = torch.randn(2, 2)
         for f in ALIAS_FN:
             with self.assertRaisesRegex(
-                torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
+                torch._dynamo.exc.BackendCompilerFailed, ".*Aliasing within branch or combine_fn might be occuring!.*"
            ):
                torch.compile(fn)(f, x)
 
@@ -7031,7 +7031,7 @@ def f(arg1, arg2):
         # as a result of auto lifting.
         for view_f in ALIAS_FN[1:]:
             with self.assertRaisesRegex(
-                torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
+                torch._dynamo.exc.BackendCompilerFailed, ".*Aliasing within branch or combine_fn might be occuring!.*"
             ):
                 torch.compile(fn)(view_f, x)
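To make the new unified message concrete, here is a minimal repro sketch (not part of the diff) of the pattern these tests exercise: a torch.cond branch that mutates its input and is rejected under functionalization.

    import torch
    from torch._higher_order_ops.utils import UnsupportedAliasMutationException
    from torch.fx.experimental.proxy_tensor import make_fx

    def true_fn(x):
        x.add_(1)  # in-place mutation of the branch input
        return x.sin()

    def false_fn(x):
        return x.cos()

    def f(x):
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

    x = torch.ones(4, 5)
    try:
        make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(x)
    except UnsupportedAliasMutationException as e:
        # Expected: "A branch or combine_fn might be modifying the input! ..."
        print(e)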
1 change: 1 addition & 0 deletions torch/_dynamo/graph_deduplication.py
@@ -120,6 +120,7 @@ def _replace_region_with_subgraph(
     invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
     fake_inputs = [node.meta["example_value"] for node in sub_args]
 
+    # TODO: We don't care here about any output-output aliasing?
     if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
         log.debug(
             "NYI: Failed to substitute region %s due to input alias or mutation",
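For context on the TODO: the guard catches input mutation and input-output aliasing, but not aliasing between two outputs. A hypothetical illustration of the two cases:

    # Input-output aliasing: caught by the guard above.
    def f(x):
        return x.view(-1)  # output is a view of the input

    # Output-output aliasing: two outputs share storage; per the TODO,
    # this case is not covered by has_potential_input_alias_or_mutation.
    def g(x):
        y = x + 1
        return y, y.view(-1)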
1 change: 1 addition & 0 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -2874,6 +2874,7 @@ def call_function(
             for node in body_gmod.graph.nodes
             if node.op == "placeholder"
         ]
+        # TODO: We don't care here about any output-output aliasing?
         if has_potential_input_alias_or_mutation(body_gmod, fake_inputs):
             raise RuntimeError(
                 f"{self.value._name} where the inputs are mutated or the "
11 changes: 11 additions & 0 deletions torch/_higher_order_ops/associative_scan.py
@@ -9,6 +9,7 @@
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
+    check_input_mutation_and_alias,
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
@@ -415,6 +416,16 @@ def associative_scan_functionalize(ctx, combine_fn, xs):
         functional_combine_fn = ctx.functionalize(
             _maybe_run_with_interpreter(combine_fn)
         )
+        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
+        sample_inputs = list(
+            itertools.chain(
+                [inp.clone() for inp in unwrapped_xs],
+                [inp.clone() for inp in unwrapped_xs],
+            )
+        )
+
+        check_input_mutation_and_alias(combine_fn, sample_inputs, pre_dispatch=pre_dispatch)
+
         ret = associative_scan_op(functional_combine_fn, unwrapped_xs)
         return ctx.wrap_tensors(ret)
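The doubled clone of unwrapped_xs reflects the combine_fn calling convention: an associative combine takes two operand slots (lhs, rhs), so the check needs one cloned sample per slot. A minimal usage sketch, assuming the public associative_scan API; combine_fn must neither mutate nor alias its operands:

    import torch
    from torch._higher_order_ops.associative_scan import associative_scan

    def combine_fn(x, y):
        return x + y  # pure: no in-place ops, no views of the inputs

    # The default combine_mode="pointwise" is validated for CUDA inputs
    # at the time of this PR, hence the device here.
    xs = torch.randn(8, 4, device="cuda")
    out = associative_scan(combine_fn, xs, dim=0)  # inclusive prefix sum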
20 changes: 2 additions & 18 deletions torch/_higher_order_ops/cond.py
@@ -19,8 +19,7 @@
 from torch._functorch.utils import exposed_in
 from torch._guards import detect_fake_mode
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
+    check_input_mutation_and_alias,
     _maybe_run_with_interpreter,
     _set_compilation_env,
     reenter_make_fx,
@@ -483,22 +482,7 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs):
         functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn))
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
         for branch in [true_fn, false_fn]:
-            if _has_potential_branch_input_mutation(
-                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
-            ):
-                raise UnsupportedAliasMutationException(
-                    "One of torch.cond branch might be modifying the input! "
-                    "Consider cloning the input before modifying it. "
-                )
-        for branch in [true_fn, false_fn]:
-            if _has_potential_branch_input_alias(
-                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
-            ):
-                raise UnsupportedAliasMutationException(
-                    "One of torch.cond branch might be aliasing the input! "
-                    "If you are returning a view of the input, please make sure "
-                    "to clone it. "
-                )
+            check_input_mutation_and_alias(branch, unwrapped_inputs, pre_dispatch=pre_dispatch)
 
         cond_return = cond_op(
             unwrapped_pred, functional_true, functional_false, unwrapped_inputs
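The new helper itself is not shown in this commit, so the following is only a sketch of what check_input_mutation_and_alias in torch/_higher_order_ops/utils.py plausibly does, inferred from the call sites and from the error-message regexes in the updated tests (including their "occuring" spelling):

    from torch._higher_order_ops.utils import (
        _has_potential_branch_input_alias,
        _has_potential_branch_input_mutation,
        UnsupportedAliasMutationException,
    )

    def check_input_mutation_and_alias(fn, sample_inputs, pre_dispatch=False):
        # Unifies the two legacy checks behind one call and one pair of messages.
        if _has_potential_branch_input_mutation(
            fn, sample_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException(
                "A branch or combine_fn might be modifying the input! "
                "Consider cloning the input before modifying it."
            )
        if _has_potential_branch_input_alias(
            fn, sample_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException(
                "Aliasing within branch or combine_fn might be occuring! "
                "If you are returning a view of the input, clone it first."
            )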
18 changes: 10 additions & 8 deletions torch/_higher_order_ops/flex_attention.py
@@ -7,7 +7,7 @@
 from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_mutation,
+    check_input_mutation_and_alias,
     _maybe_reenter_make_fx,
     autograd_not_implemented,
     reenter_make_fx,
@@ -420,13 +420,15 @@ def flex_attention_functionalize(
     functional_score_mod = ctx.functionalize(score_mod)
     pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
     with TransformGetItemToIndex():
-        mutates = _has_potential_branch_input_mutation(
-            score_mod, example_vals, pre_dispatch
-        )
-        # We only care about mutations of existing buffers since we can't replay these.
-        # However, we can just error if anything is detected
-        if mutates:
-            raise UnsupportedAliasMutationException("Mutations detected in score_mod")
+        check_input_mutation_and_alias(score_mod, example_vals, pre_dispatch)
+        # mutates = _has_potential_branch_input_mutation(
+        #     score_mod, example_vals, pre_dispatch
+        # )
+        # # We only care about mutations of existing buffers since we can't replay these.
+        # # However, we can just error if anything is detected
+        # if mutates:
+        #     raise UnsupportedAliasMutationException("Mutations detected in score_mod")
+
 
     out = flex_attention(
         query_unwrapped,
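As a hedged illustration of what this check guards against in score_mod (hypothetical functions, not from the diff): mutations of captured buffers cannot be replayed, so score_mod must stay functional.

    import torch

    bias = torch.zeros(128)

    def bad_score_mod(score, b, h, q_idx, kv_idx):
        bias.add_(1.0)  # mutates a captured buffer; cannot be replayed
        return score + bias[q_idx]

    def good_score_mod(score, b, h, q_idx, kv_idx):
        return score + bias[q_idx]  # pure read is fine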
28 changes: 14 additions & 14 deletions torch/_higher_order_ops/hints_wrap.py
@@ -3,8 +3,7 @@
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
+    check_input_mutation_and_alias,
     autograd_not_implemented,
     reenter_make_fx,
     unique_graph_id,
@@ -97,18 +96,19 @@ def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
     with ctx.redispatch_to_next():
         functional_body_fn = ctx.functionalize(body_fn)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
-        if _has_potential_branch_input_mutation(
-            body_fn, unwrapped_args, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "body_fn of hints_wrapper might be modifying the input!"
-            )
-        if _has_potential_branch_input_alias(
-            body_fn, unwrapped_args, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "body_fn of hints_wrapper might be aliasing the input!"
-            )
+        check_input_mutation_and_alias(body_fn, unwrapped_args, pre_dispatch=pre_dispatch)
+        # if _has_potential_branch_input_mutation(
+        #     body_fn, unwrapped_args, pre_dispatch=pre_dispatch
+        # ):
+        #     raise UnsupportedAliasMutationException(
+        #         "body_fn of hints_wrapper might be modifying the input!"
+        #     )
+        # if _has_potential_branch_input_alias(
+        #     body_fn, unwrapped_args, pre_dispatch=pre_dispatch
+        # ):
+        #     raise UnsupportedAliasMutationException(
+        #         "body_fn of hints_wrapper might be aliasing the input!"
+        #     )
         outputs = hints_wrapper(
             functional_body_fn,
             unwrapped_args,
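A minimal hints_wrapper usage sketch under the same constraint (body_fn must not mutate or alias its inputs); the hints payload here is made up:

    import torch
    from torch._higher_order_ops.hints_wrap import hints_wrapper

    def body_fn(x):
        return (x.sin(),)  # fresh output tuple: no mutation, no aliasing

    x = torch.randn(4)
    (out,) = hints_wrapper(body_fn, (x,), {}, hints={"schedule": "outer"})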
22 changes: 11 additions & 11 deletions torch/_higher_order_ops/map.py
@@ -5,8 +5,7 @@
 from torch._dispatch.python import suspend_functionalization
 from torch._functorch.aot_autograd import AOTConfig, create_joint
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
+    check_input_mutation_and_alias,
     _maybe_run_with_interpreter,
     reenter_make_fx,
     UnsupportedAliasMutationException,
@@ -250,15 +249,16 @@ def map_functionalize(ctx, f, xs, pos_args):
     with disable_proxy_modes_tracing():
         example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
     pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
-    if _has_potential_branch_input_mutation(
-        f, example_inputs, pre_dispatch=pre_dispatch
-    ):
-        raise UnsupportedAliasMutationException("torch.map is mutating the input!")
+    check_input_mutation_and_alias(f, example_inputs, pre_dispatch=pre_dispatch)
+    # if _has_potential_branch_input_mutation(
+    #     f, example_inputs, pre_dispatch=pre_dispatch
+    # ):
+    #     raise UnsupportedAliasMutationException("torch.map is mutating the input!")
 
-    if _has_potential_branch_input_alias(
-        f, example_inputs, pre_dispatch=pre_dispatch
-    ):
-        raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
+    # if _has_potential_branch_input_alias(
+    #     f, example_inputs, pre_dispatch=pre_dispatch
+    # ):
+    #     raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
 
     map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
     return ctx.wrap_tensors(map_return)
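The mapped function is checked against per-slice example inputs; returning the slice itself is exactly the aliasing case the old "torch.map is aliasing the input!" message covered. A hedged sketch of that failure mode:

    import torch
    from functorch.experimental import control_flow

    def bad_f(x, y):
        return x  # returns the input slice itself: aliasing, rejected

    def f(xs, y):
        return control_flow.map(bad_f, xs, y)

    xs, y = torch.ones(3, 2, 4), torch.ones(4)
    functional_f = torch.func.functionalize(f)
    functional_f(xs, y)  # raises UnsupportedAliasMutationException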
18 changes: 4 additions & 14 deletions torch/_higher_order_ops/scan.py
@@ -9,8 +9,7 @@
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
+    check_input_mutation_and_alias,
     _set_compilation_env,
     autograd_not_implemented,
     reenter_make_fx,
@@ -454,18 +453,9 @@ def scan_functionalize(ctx, combine_fn, init, xs, reverse, additional_inputs):
             unwrapped_additional_inputs,
         )
     )
-    if _has_potential_branch_input_mutation(
-        combine_fn, sample_inputs, pre_dispatch=pre_dispatch
-    ):
-        raise UnsupportedAliasMutationException(
-            "Combine_fn might be modifying the input!"
-        )
-    if _has_potential_branch_input_alias(
-        combine_fn, sample_inputs, pre_dispatch=pre_dispatch
-    ):
-        raise UnsupportedAliasMutationException(
-            "Combine_fn might be aliasing the input!"
-        )
+
+    check_input_mutation_and_alias(combine_fn, sample_inputs, pre_dispatch=pre_dispatch)
+
     ret = scan_op(
         functional_combine_fn,
         unwrapped_init,
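For scan, sample_inputs chains the init leaves, one slice of xs, and any additional inputs, mirroring the combine_fn signature. A minimal usage sketch, assuming the torch._higher_order_ops.scan API:

    import torch
    from torch._higher_order_ops.scan import scan

    def combine_fn(carry, x):
        next_carry = carry + x
        return next_carry, next_carry.clone()  # clone: outputs must not alias

    init = torch.zeros(4)
    xs = torch.randn(8, 4)
    carry, ys = scan(combine_fn, init, xs)  # running sums over dim 0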