[HOP] Mutation and alias rework by bohnstingl · Pull Request #146658 · pytorch/pytorch
[HOP] Mutation and alias rework #146658


Open · bohnstingl wants to merge 66 commits into main

Conversation

@bohnstingl (Collaborator) commented Feb 7, 2025

pytorch-bot commented Feb 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146658

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 55b0301 with merge base 084c4aa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl (Collaborator, Author)

@pytorchbot label "topic: not user facing"

@pytorch-bot added the topic: not user facing label Feb 7, 2025
@ydwu4 self-requested a review February 7, 2025 18:10
@bohnstingl (Collaborator, Author)

@ydwu4 I reworked the mutation and alias checks. I moved the checks into dynamo for scan, associative_scan, while_loop and cond. For map I also included the new check, but since it does not yet use backend='eager', I did not move the check to dynamo. Please let me know what you think.
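
To illustrate, here is a minimal sketch (my own example, not taken from the PR's tests) of the kind of subgraph these checks are meant to reject, a cond branch that mutates one of its inputs:

import torch

def true_fn(x):
    x.add_(1)       # in-place mutation of a branch input; the reworked check should reject this
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile(backend="eager", fullgraph=True)
def f(pred, x):
    # With the checks moved into dynamo, the mutation is flagged while the branch
    # is being traced rather than only later during functionalization.
    return torch.cond(pred, true_fn, false_fn, (x,))

# f(torch.tensor(True), torch.randn(3))  # expected to raise because of the mutation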

@bohnstingl requested a review from ydwu4 March 2, 2025 23:02
@ydwu4 (Contributor) left a comment:

Looks good overall. Left a few minor comments.

    # This case is for SymInts and other non-Tensor elements
    inputs_fake.append(val)
else:
    # This case is for ints
Contributor:

can assert they're ints?

@bohnstingl (Collaborator, Author):

Done.

I was wondering, though, whether we could improve this in general. Are you fine with how we currently collect the fake inputs?
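
As a rough sketch, the assertion discussed above could look like this, reusing the names from the diff fragment (this mirrors that fragment and is not the PR's exact code):

else:
    # This case is for ints; assert it so that unexpected types fail loudly.
    assert isinstance(val, int), f"expected int, got {type(val)}"
    inputs_fake.append(val)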

_maybe_reenter_make_fx,
autograd_not_implemented,
# check_input_mutation,
Contributor:

remove?

@bohnstingl (Collaborator, Author):

Sure. I updated the PR and have now reworked pretty much all HOPs.

@@ -442,27 +439,6 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
    unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
    with ctx.redispatch_to_next():
        functional_combine_fn = ctx.functionalize(combine_fn)
Contributor:

We might still want to keep the check in the functionalization key, in case someone is using the HOP directly, which would bypass the dynamo checks.

@bohnstingl (Collaborator, Author):

Right, as discussed offline, I reintroduced the checks into the functionalization key as well.
If a HOP uses backend='eager', we now have the check both in the functionalization key and in dynamo.
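
A sketch of the scenario described above (my own example; the exact eager invocation path is illustrative, not taken from the PR): calling the higher-order op directly never goes through dynamo, so only a check registered under the functionalization key can still catch an aliasing or mutating subgraph.

import torch

def aliasing_branch(x):
    return x.view(-1)            # output aliases the input tensor

def clean_branch(x):
    return x.reshape(-1).clone()

x = torch.randn(2, 2)
pred = torch.tensor(True)

# Direct invocation of the HOP, bypassing torch.compile/dynamo entirely;
# under functionalization this is where the reintroduced check would fire.
# out = torch.ops.higher_order.cond(pred, aliasing_branch, clean_branch, (x,))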

inp_out_alias_map,
out_out_alias_map)

def has_potential_input_mutation_or_alias(gm, inputs, pre_dispatch=False):
Contributor:

nit: Probably we should still name this as has_potential_input_alias_or_mutation? The name change seems unnecessary

@bohnstingl (Collaborator, Author) commented Mar 5, 2025:

I agree; I corrected it.
In fact, I revised the name of the inner helper potential_input_alias_or_mutation as well, and as a nit I also adjusted the return arguments. Now, as the name suggests, the aliases are returned first, followed by the mutations.

@bohnstingl marked this pull request as ready for review March 5, 2025 18:13
@bohnstingl requested a review from zou3519 as a code owner March 5, 2025 18:13
Comment on lines +891 to +896
# TODO: This is an unexpected behavior for cond
# Without this additional multiplication,
# the output of the backward graph would alias the
# inputs, as the gradients are just 1s and thus get optimized
def true_fn(x):
    return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
    return (x["t"][0] * 2.0) + x["t"][1]["b"] * x["t"][2][0]
@bohnstingl (Collaborator, Author):

This may be a bit unexpected to the user. Currently we don't allow aliases, including input-output aliases. This is problematic because the gradients could be just 1s, in which case the gradients (arguments) from upstream are simply passed along, which then triggers the alias checks.

A naive solution would be to disable the input-output alias check, but I am not sure whether that would cause problems.
Is there another solution to this?
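
To make the pitfall concrete, here is a condensed illustration based on the test change above: with the plain addition, the gradient with respect to x["t"][0] is exactly the incoming upstream gradient, so the generated backward subgraph would return one of its inputs unchanged and trip the input-output alias check; the extra multiplication forces the backward graph to compute a fresh tensor.

def true_fn_aliasing(x):
    # d(out) / d(x["t"][0]) is the upstream grad itself, so the backward graph
    # would forward its grad input untouched (an input-output alias).
    return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]

def true_fn_ok(x):
    # The factor 2.0 makes the backward scale the upstream grad,
    # producing a new tensor instead of an alias.
    return (x["t"][0] * 2.0) + x["t"][1]["b"] * x["t"][2][0]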

Contributor:

A better way is to properly support it through auto_functionalized. This is still WIP though.

@ydwu4 (Contributor) left a comment:

Tests are failing. Added a few comments

input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()

for node in self.graph.nodes:
    if node.op == "placeholder":
        example_value = node.meta["example_value"]
        example_value = _collect_fake_inputs([node])[0]
Contributor:

what happens here? I don't expect it will error, will it?

@bohnstingl (Collaborator, Author):

No, it does error out in some test cases. The issue is that the example value can be a BatchedTensorImpl, for which untyped_storage() doesn't exist.
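
For context, a rough sketch of the placeholder/storage bookkeeping being discussed (names and structure are mine, not the actual dynamo code), including the failure mode where the example value exposes no storage:

import torch
from torch.multiprocessing.reductions import StorageWeakRef

def collect_input_storages(graph: torch.fx.Graph) -> dict[StorageWeakRef, torch.fx.Node]:
    # Map each placeholder's storage to its node so later nodes can be
    # compared against the inputs for aliasing/mutation.
    input_storages: dict[StorageWeakRef, torch.fx.Node] = {}
    for node in graph.nodes:
        if node.op == "placeholder":
            example_value = node.meta.get("example_value")
            if isinstance(example_value, torch.Tensor):
                try:
                    input_storages[StorageWeakRef(example_value.untyped_storage())] = node
                except RuntimeError:
                    # e.g. functorch BatchedTensor wrappers do not expose a storage,
                    # which is the failure mode described above.
                    pass
    return input_storages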

@@ -1336,6 +1340,8 @@ def create_unbacked_sym_node_var(tx) -> SymNodeVariable:
    source_target=self.value,
    set_subgraph_inputs="flatten_manual",
    should_flatten_outputs=True,
    supports_input_mutation=False,
Contributor:

Can we put these two as class fields, like in BaseHOPVariable?

@bohnstingl (Collaborator, Author):

Done.
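
A sketch of the suggested pattern, assuming the flags become class attributes that individual HOP variable classes override (the subclass name and default value below are illustrative; the second flag from the diff hunk is not visible in the excerpt, so it is omitted here):

class BaseHOPVariable:
    # Illustrative default; subclasses override as needed.
    supports_input_mutation: bool = True

class CondLikeHOPVariable(BaseHOPVariable):
    # cond-style HOPs reject input mutation, matching supports_input_mutation=False above.
    supports_input_mutation = False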

@@ -699,17 +741,19 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"


# TODO: Return more detailed information as to which node
# causes a mutation or an alias. This may require per-operator tensor version checking
def check_input_alias_and_mutation(
Contributor:

Why do we need to move mutated_inputs to the end?

@bohnstingl (Collaborator, Author):

Well, I just moved it to be consistent with the function name: it has "alias" before "mutation", and that's why I moved mutated_inputs towards the end. WDYT?
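
Paraphrasing the agreed-upon shape (the signature and map names are taken from the fragments quoted above; the body is a stub, not the PR's implementation), the alias information comes first and the mutated inputs come last, matching the order of "alias" and "mutation" in the name:

def check_input_alias_and_mutation(gm, inputs):
    inp_out_alias_map: dict[int, int] = {}   # input index -> output index it aliases
    out_out_alias_map: dict[int, int] = {}   # output index -> output index it aliases
    mutated_inputs: list[int] = []           # indices of inputs mutated while running gm
    # ... run gm on (fake) inputs and populate the structures ...
    return inp_out_alias_map, out_out_alias_map, mutated_inputs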

@bohnstingl bohnstingl requested review from ezyang and Chillee as code owners May 15, 2025 21:20
@ydwu4 (Contributor) left a comment:

The change looks good. Not sure why the tests started to fail.

@bohnstingl (Collaborator, Author)

I think I found the issue. There was one place in inductor where I missed the rearrangement of the return values from alias_mutation. Fingers crossed this time.

@bohnstingl requested a review from ydwu4 May 16, 2025 23:19
Labels: module: dynamo, module: inductor, open source, topic: not user facing, triaged
4 participants