Reland "AOTAutograd: Go down inference path if no outputs require grad (#111011)" by bdhirsh · Pull Request #111347 · pytorch/pytorch · GitHub

Reland "AOTAutograd: Go down inference path if no outputs require grad (#111011)" #111347


Closed
wants to merge 2 commits into from

Conversation

bdhirsh
Contributor
@bdhirsh bdhirsh commented Oct 16, 2023

Re-land of #111011.

The original PR ended up having a bad interaction with code that tried to run `torch.compile` under `with torch.inference_mode`, which caused some internal tests to fail.

The issue was that:

(1) AOTInductor invokes the pattern matcher passes in inductor.

(2) The pattern matcher registers some code with [training_graph](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/pad_mm.py#L461).

(3) The `training_graph` function expects to be able to set the global autograd state to `requires_grad` and always get out a joint graph (assertion [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/pattern_matcher.py#L1196)).

(4) However, when `inference_mode` is active and you run AOTAutograd, AOTAutograd will see that none of the traced function's outputs require grad and (now correctly) conclude that it is tracing an inference graph, which trips the assert above.

After talking to Bin, it sounds like these training-only patterns aren't necessary when we know we are compiling an inference graph (which should always be the case if you're running `torch.compile` with `inference_mode`). So I updated the pattern matcher to ignore any pattern matches that use `training_graph` when `inference_mode` is enabled.
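As a rough illustration (the model and shapes below are made up, not the internal test), this is the kind of user code that now takes the inference path:

```
import torch

@torch.compile
def f(x):
    return x.sin().cos()

x = torch.randn(8, 8)

# Under inference_mode no output can require grad, so AOTAutograd now
# (correctly) traces an inference graph; before this change, the training-graph
# pattern registrations would hit the joint-graph assert on this path.
with torch.inference_mode():
    out = f(x)
```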

Stack from ghstack (oldest at bottom):

This reverts commit cf6b1cd.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@pytorch-bot
pytorch-bot bot commented Oct 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111347

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit af715b3 with merge base e0e15a4:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```
# If inference mode is enabled during compilation, assume that we don't
# want to match on any training graph patterns
if torch.is_inference_mode_enabled():
    return False
```
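For reference, `torch.is_inference_mode_enabled()` simply reports whether the calling thread is currently inside an `inference_mode` region; a quick illustration (not part of the diff):

```
import torch

print(torch.is_inference_mode_enabled())      # False outside the context manager

with torch.inference_mode():
    print(torch.is_inference_mode_enabled())  # True inside
```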
Contributor Author
cc @Chillee. More details in the PR description, but let me know what you think.

Collaborator
Hmm... this is a bit weird to me, but I guess it's fine.

In particular, there's no reason an "inference" graph can't also have joint graph patterns apply to it. For example, if we capture forwards and backwards in a single graph, we might go down the inference graph path anyways?

I guess this wouldn't work today anyways, as these passes are hardcoded to trigger in the "joint graph" part. But something to keep in mind in the future.

@bdhirsh bdhirsh added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 16, 2023
…require grad (#111011)""

bdhirsh added a commit that referenced this pull request Oct 16, 2023
#111011)"

This reverts commit cf6b1cd.

ghstack-source-id: e66471c
Pull Request resolved: #111347
@bdhirsh
Contributor Author
bdhirsh commented Oct 17, 2023

@pytorchbot help

@pytorch-bot
pytorch-bot bot commented Oct 17, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'help' (choose from 'merge', 'revert', 'rebase', 'label', 'drci')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

@bdhirsh
Contributor Author
bdhirsh commented Oct 17, 2023

@pytorchbot --help

@pytorch-bot
pytorch-bot bot commented Oct 17, 2023

PyTorchBot Help

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

In order to invoke the bot on your PR, include a line that starts with
@pytorchbot anywhere in a comment. That line will form the command; no
multi-line commands are allowed. 

Example:
    Some extra context, blah blah, wow this PR looks awesome

    @pytorchbot merge

optional arguments:
  -h, --help            Show this help message and exit.

command:
  {merge,revert,rebase,label,drci}
    merge               Merge a PR
    revert              Revert a PR
    rebase              Rebase a PR
    label               Add label to a PR
    drci                Update Dr. CI

Merge

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Merge an accepted PR, subject to the rules in .github/merge_rules.json.
By default, this will wait for all required checks (lint, pull) to succeed before merging.

optional arguments:
  -f MESSAGE, --force MESSAGE
                        Merge without checking anything. This requires a reason for auditing purposes, for example:
                        @pytorchbot merge -f 'Minor update to fix lint. Expecting all PR tests to pass'
                        
                        Please use `-f` as last resort, prefer `--ignore-current` to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.
  -i, --ignore-current  Merge while ignoring the currently failing jobs.  Behaves like -f if there are no pending jobs.
  -ic                   Old flag for --ignore-current. Deprecated in favor of -i.
  -r [{viable/strict,main}], --rebase [{viable/strict,main}]
                        Rebase the PR to re-run checks before merging. Accepts viable/strict or main as branch options and will default to viable/strict if not specified.

Revert

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Revert a merged PR. This requires that you are a Meta employee.

Example:
  @pytorchbot revert -m="This is breaking tests on trunk. hud.pytorch.org/" -c=nosignal

optional arguments:
  -m MESSAGE, --message MESSAGE
                        The reason you are reverting, will be put in the commit message. Must be longer than 3 words.
  -c {nosignal,ignoredsignal,landrace,weird,ghfirst}, --classification {nosignal,ignoredsignal,landrace,weird,ghfirst}
                        A machine-friendly classification of the revert reason.

Rebase

usage: @pytorchbot rebase [-s | -b BRANCH]

Rebase a PR. Rebasing defaults to the stable viable/strict branch of pytorch.
Repeat contributor may use this command to rebase their PR.

optional arguments:
  -s, --stable          [DEPRECATED] Rebase onto viable/strict
  -b BRANCH, --branch BRANCH
                        Branch you would like to rebase to

Label

usage: @pytorchbot label labels [labels ...]

Adds label to a PR

positional arguments:
  labels  Labels to add to given Pull Request

Dr CI

usage: @pytorchbot drci 

Update Dr. CI. Updates the Dr. CI comment on the PR in case it's gotten out of sync with actual CI results.

@bdhirsh
Contributor Author
bdhirsh commented Oct 17, 2023

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / linux-focal-rocm5.6-py3.8 / test (default, 1, 3, linux.rocm.gpu, unstable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@bdhirsh bdhirsh added this to the 2.1.1 milestone Oct 18, 2023
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/471/head branch October 20, 2023 14:24
atalman pushed a commit to atalman/pytorch that referenced this pull request Oct 26, 2023
pytorch#111011)" (pytorch#111347)

Pull Request resolved: pytorch#111347
Approved by: https://github.com/Chillee
atalman pushed a commit to atalman/pytorch that referenced this pull request Oct 26, 2023
pytorch#111011)" (pytorch#111347)

pytorchmergebot pushed a commit to atalman/pytorch that referenced this pull request Nov 2, 2023
pytorch#111011)" (pytorch#111347)

@atalman atalman modified the milestones: 2.1.1, 2.1.2 Nov 8, 2023
bdhirsh added a commit that referenced this pull request Nov 14, 2023
…ires_grad=True tensor out of the graph for inference"

bdhirsh added a commit that referenced this pull request Nov 14, 2023
…ensor out of the graph for inference"

pytorchmergebot pushed a commit that referenced this pull request Nov 15, 2023
…f the graph for inference (#113584)

The original behavior of torch.compile w.r.t. input mutations is that if an input to a graph was mutated **and** requires grad, we keep the input mutation outside of the graph and replay it at runtime.

This is important because, e.g., an input can have outstanding aliases, and mutating the input in eager mode will cause autograd to change the `grad_fn` of all outstanding aliases.

It looks like landing #111347 changed this behavior slightly:
* The linked PR makes it possible for AOTAutograd to go down the inference code path, even if some inputs require grad (because all of the outputs of the graph were seen to not require grad)
* AOTAutograd's logic in the inference code path today is to **always** keep input mutations in the graph.

This PR fixes that regression: regardless of inference vs. training, we should always keep input mutations outside of the graph if the input requires_grad.
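A hedged sketch of the situation this guards against (names and shapes are illustrative, not the regression test):

```
import torch

@torch.compile
def f(x):
    x.mul_(2)          # in-place mutation of an input that requires grad
    return x.detach()  # no output requires grad, so this is an inference graph

base = torch.ones(4, requires_grad=True).clone()
alias = base.view(4)   # an outstanding alias of the graph input

f(base)
# Replaying the mutation outside the graph lets eager-mode autograd update
# metadata on `base` and its aliases (e.g. alias.grad_fn, version counters).
```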

Pull Request resolved: #113584
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #113267, #113416
@atalman atalman removed this from the 2.1.2 milestone Nov 24, 2023
bdhirsh added a commit that referenced this pull request Nov 28, 2023
…raph if they are under no_grad, even if they require_grad"

bdhirsh added a commit that referenced this pull request Nov 28, 2023
… under no_grad, even if they require_grad"

bdhirsh added a commit that referenced this pull request Nov 29, 2023
…raph if they are under no_grad, even if they require_grad"

bdhirsh added a commit that referenced this pull request Nov 29, 2023
… under no_grad, even if they require_grad"

bdhirsh added a commit that referenced this pull request Nov 29, 2023
…raph if they are under no_grad, even if they require_grad"

bdhirsh added a commit that referenced this pull request Nov 29, 2023
… under no_grad, even if they require_grad"

pytorchmergebot pushed a commit that referenced this pull request Nov 29, 2023
…rad, even if they require_grad (#114646)

Quick recap of events:

(1) #111347, which fixed a perf regression in 2.1 compared to 2.0, introduced a correctness problem around input mutations on inputs that require grad that show up in an inference-only graph (the specific case where this can happen is rare and nobody reported the issue, but it was fixed a few weeks later)

(2) That fix happened here: #113584, which makes sure to keep input mutations outside of the graph, so the autograd engine can set metadata properly on them

(3) That in turn caused a slight regression compared to (1), which is what this PR attempts to fix. In particular, for code like the example below, it is safe to keep the mutations in the graph:

```
@torch.compile
def f(x):
    x.mul_(2)

x = torch.ones(2, requires_grad=True).clone()
# x requires_grad, so the input mutation will change some autograd metadata, like the version counter
# However, the mutation is under no_grad, so we don't have to worry about e.g. aliases of x having their .grad_fn fields changed
with torch.no_grad():
    f(x)
```

This particular case is pretty important to the shampoo optimizer code, which is run under `torch.compile`, and mutates parameters (which require grad).
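For context, a hypothetical optimizer-style step in the same spirit (illustrative only, not the actual shampoo code): parameters that require grad are mutated in place inside a compiled function, and the caller runs it under `no_grad`:

```
import torch

params = [torch.randn(4, requires_grad=True) for _ in range(2)]
grads = [torch.randn(4) for _ in range(2)]

@torch.compile
def step(params, grads, lr=0.1):
    # In-place updates of requires_grad parameters; safe to keep in the graph
    # because the caller wraps the call in torch.no_grad().
    for p, g in zip(params, grads):
        p.add_(g, alpha=-lr)

with torch.no_grad():
    step(params, grads)
```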

Pull Request resolved: #114646
Approved by: https://github.com/zou3519