[HOP] Mutation and alias rework #146658
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146658
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV; if your PR is affected, please view it below.
✅ No Failures as of commit 55b0301 with merge base 084c4aa.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
…loop, cond and map
@ydwu4 I reworked the mutation and alias checks. I moved the checks into dynamo for scan, associative_scan, while_loop and cond. For map I also included the new check, but since it does not yet use backend='eager', I did not move the check to dynamo. Please let me know what you think.
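For illustration, a minimal sketch of the kind of program the dynamo-level check is meant to reject; the exact error type and message are an assumption here, the point is that the mutation is reported while the branch is being traced:

```python
import torch

def true_fn(x):
    x.add_(1)          # in-place mutation of a branch input
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile(backend="eager", fullgraph=True)
def f(pred, x):
    return torch.cond(pred, true_fn, false_fn, (x,))

# With the checks moved into dynamo, the mutation in true_fn should be
# flagged during tracing, before any backend runs.
try:
    f(torch.tensor(True), torch.randn(3))
except Exception as e:
    print(type(e).__name__, e)
```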
Looks good overall. Left a few minor comments.
# This case is for SymInts and other non-Tensor elements
inputs_fake.append(val)
else:
# This case is for ints
can assert they're ints?
Done.
I was wondering, though, whether we could improve this in general. Are you fine with how we currently collect the fake inputs?
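As a rough sketch (not the actual dynamo helper; `collect_fake_inputs_sketch` is a hypothetical name), the case split with the added assert might look like this:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def collect_fake_inputs_sketch(example_values):
    inputs_fake = []
    for val in example_values:
        if isinstance(val, (FakeTensor, torch.SymInt)):
            # FakeTensors and SymInts can be collected as-is
            inputs_fake.append(val)
        else:
            # Per the review comment: everything else must be a plain int
            assert isinstance(val, int), f"expected int, got {type(val)}"
            inputs_fake.append(val)
    return inputs_fake

with FakeTensorMode():
    fake = torch.empty(3)          # a FakeTensor under this mode
print(collect_fake_inputs_sketch([fake, 4]))
```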
_maybe_reenter_make_fx,
autograd_not_implemented,
# check_input_mutation,
remove?
Sure, I updated the PR and have now reworked pretty much all HOPs.
torch/_higher_order_ops/scan.py
Outdated
@@ -442,27 +439,6 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
    unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
    with ctx.redispatch_to_next():
        functional_combine_fn = ctx.functionalize(combine_fn)
We might still want to keep the check in the functionalization key, in case someone is using the HOP directly, which could bypass the dynamo checks.
Right, as discussed offline, I reintroduced the checks into the functionalization key as well. If the HOP uses backend='eager', we now have the check both in the functionalization key and in dynamo.
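To illustrate the mechanism the functionalization-key guard relies on (this is not the actual scan_functionalize code, just a minimal sketch): under functionalization, an input mutation in the subgraph is replayed onto the input as a trailing copy_, which a check at that key can detect even when the HOP is invoked directly and dynamo never sees the program.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

def combine_fn(carry, x):
    carry.add_(x)                  # mutates a subgraph input
    return carry * 1.0, carry + x

# Tracing the functionalized combine_fn exposes the mutation as an explicit
# copy_ node writing back into the carry placeholder.
gm = make_fx(functionalize(combine_fn))(torch.zeros(3), torch.ones(3))
print(gm.graph)  # expect an aten.copy_ back into the carry input
```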
torch/_higher_order_ops/utils.py
Outdated
inp_out_alias_map,
out_out_alias_map)

def has_potential_input_mutation_or_alias(gm, inputs, pre_dispatch=False): |
nit: we should probably still name this has_potential_input_alias_or_mutation? The name change seems unnecessary.
I agree, I corrected it.
In fact, I also revised the name of the inner helper potential_input_alias_or_mutation, and as a nit adjusted the return values: as the name suggests, the aliases are now returned first, followed by the mutations.
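A small usage sketch of the renamed helper; the import path and the way inputs are passed here are assumptions based on this thread, not the final API:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation

def branch(x, y):
    return x.view(-1), x + y   # first output aliases the input x

x, y = torch.randn(3), torch.randn(3)
gm = make_fx(branch)(x, y)

# Expected to flag the graph because an output aliases one of its inputs.
print(has_potential_input_alias_or_mutation(gm, [x, y]))
```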
# TODO: This is an unexpected behavior for cond
# Without this additional multiplication,
# the output of the backward graph would alias the
# inputs, as the gradients are just 1s and thus get optimized
def true_fn(x):
-    return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
+    return (x["t"][0] * 2.0) + x["t"][1]["b"] * x["t"][2][0]
This may be a bit unexpected to the user. Currently we don't allow aliases, including input-output aliases. This is problematic because the gradients can be just 1s, in which case the upstream gradients (the arguments) are simply passed along, which then triggers the alias checks.
A naive solution would be to disable the input-output alias check, but I am not sure whether that causes problems.
Is there another solution to this?
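A small eager illustration of the issue described above: for a plain addition the backward just forwards the upstream gradient, so the "computed" gradient can share storage with (or literally be) the tensor that was passed in. When that backward is captured as a graph for a HOP, it becomes an output aliasing an input, which the alias check rejects; scaling by 2.0 forces a fresh tensor.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

out = x + y
g_up = torch.ones_like(out)                      # upstream gradient
gx, gy = torch.autograd.grad(out, (x, y), g_up)
print(gx.data_ptr() == g_up.data_ptr())          # often True: gradient aliases g_up

out2 = x * 2.0 + y                               # the workaround from the test
g2x, _ = torch.autograd.grad(out2, (x, y), g_up)
print(g2x.data_ptr() == g_up.data_ptr())         # False: the multiply produces a new tensor
```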
A better way is to properly support it through auto_functionalized. This is still WIP though.
Tests are failing. Added a few comments
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()

for node in self.graph.nodes:
    if node.op == "placeholder":
-        example_value = node.meta["example_value"]
+        example_value = _collect_fake_inputs([node])[0]
what happens here? I don't expect it will error, will it?
No, it does error out. In some test cases the example value is a BatchedTensorImpl, for which self.untyped_storage() doesn't exist.
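For context, a tiny illustration of why the raw example value can't be used for the storage lookup. _add_batch_dim is an internal functorch helper, used here only to construct such a tensor; treat this snippet as an assumption about the failing case rather than the actual test:

```python
import torch
from torch._C._functorch import _add_batch_dim

# A functorch BatchedTensorImpl wraps another tensor and exposes no storage of
# its own, so untyped_storage() raises instead of returning something we could
# take a StorageWeakRef to.
t = torch.randn(4, 3)
bt = _add_batch_dim(t, 0, 1)   # wrap t as batched over dim 0 at vmap level 1
try:
    bt.untyped_storage()
except RuntimeError as e:
    print("untyped_storage() failed:", e)
```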
@@ -1336,6 +1340,8 @@ def create_unbacked_sym_node_var(tx) -> SymNodeVariable:
    source_target=self.value,
    set_subgraph_inputs="flatten_manual",
    should_flatten_outputs=True,
+   supports_input_mutation=False,
Can we put these two as class fields, like BaseHOPVariable?
Done.
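A sketch of the suggestion above: instead of threading the flags through every call, declare them once as class attributes on the HOP's VariableTracker, the way BaseHOPVariable does. The class names and the second flag (supports_aliasing) below are illustrative stand-ins, not the exact dynamo classes:

```python
class TorchHigherOrderOperatorVariable:           # stand-in for the dynamo base class
    supports_input_mutation: bool = True
    supports_aliasing: bool = True

class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
    # cond branches may neither mutate nor alias their inputs
    supports_input_mutation = False
    supports_aliasing = False

# Call sites can then read the policy from the class instead of passing kwargs:
print(CondHigherOrderVariable.supports_input_mutation)  # False
```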
@@ -699,17 +741,19 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
    ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"


# TODO: Return more detailed information as to which node
# causes a mutation or an alias. This may require per-operator tensor version checking.
def check_input_alias_and_mutation(
Why do we need to move mutated_inputs to the end?
Well, I just moved it to be consistent with the function name: it says alias and mutation, so the aliases come first and the mutations go at the end. WDYT?
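For clarity, this is the unpacking order implied by the change as it has to appear at every call site. The names follow the diff above; gm and fake_inputs stand for a traced subgraph and its inputs, so this is an illustrative sketch rather than runnable code:

```python
# New ordering: the alias maps come first, the mutated inputs last.
(
    inp_inp_alias_map,   # input index  -> input index it aliases
    inp_out_alias_map,   # input index  -> output index it aliases
    out_out_alias_map,   # output index -> output index it aliases
    mutated_inputs,      # inputs that are mutated inside the subgraph
) = check_input_alias_and_mutation(gm, fake_inputs)
```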
The change looks good. Not sure why the tests started failing.
I think I found the issue. There was one place in inductor where I missed the rearrangement of the return values from alias_mutation. Fingers crossed for this time.
This PR reworks the way input mutations and aliases are checked.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @ydwu4