Reland "AOTAutograd: Go down inference path if no outputs require grad (#111011)" #111347
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111347
Note: Links to docs will display an error until the doc builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit af715b3 with merge base e0e15a4.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```
# If inference mode is enabled during compilation, assume that we don't
# want to match on any training graph patterns
if torch.is_inference_mode_enabled():
    return False
```
cc @Chillee. More details in the PR description, but let me know what you think.
Hmm... this is a bit weird to me, but I guess it's fine.
In particular, there's no reason an "inference" graph can't also have joint graph patterns apply to it. For example, if we capture forwards and backwards in a single graph, we might go down the inference graph path anyways?
I guess this wouldn't work today anyways, as these passes are hardcoded to trigger in the "joint graph" part. But something to keep in mind in the future.
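For context, a minimal sketch (not from the PR) of the signal the guard above keys on: `torch.is_inference_mode_enabled()` simply reports whether an `inference_mode` context is currently active.

```python
import torch

# Illustrative only: the guard in the diff above skips training-only pattern
# matches whenever an inference_mode context is active at compile time.
print(torch.is_inference_mode_enabled())      # False

with torch.inference_mode():
    print(torch.is_inference_mode_enabled())  # True
```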
@pytorchbot help
❌ 🤖 pytorchbot command failed.
@pytorchbot --help
PyTorchBot Help: Merge, Revert, Rebase, Label, Dr CI
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: trunk / linux-focal-rocm5.6-py3.8 / test (default, 1, 3, linux.rocm.gpu, unstable). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#111347. Approved by: https://github.com/Chillee
…f the graph for inference (#113584)

The original behavior of torch.compile w.r.t. input mutations maintains that if an input to a graph was mutated, **and** requires grad, we will keep the input mutation outside of the graph and replay it at runtime. This is important because, e.g., an input can have outstanding aliases, and mutating the input in eager mode will cause autograd to change the `grad_fn` of all outstanding aliases.

It looks like landing #111347 changed this behavior slightly:
* The linked PR makes it possible for AOTAutograd to go down the inference code path, even if some inputs require grad (because all of the outputs of the graph were seen to not require grad)
* AOTAutograd's logic in the inference code path today is to **always** keep input mutations in the graph.

This PR fixes that regression: regardless of inference vs. training, we should always keep input mutations outside of the graph if the input requires_grad.

Pull Request resolved: #113584
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #113267, #113416
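A small sketch (illustrative only, not from the PR's tests) of the aliasing hazard described above: mutating a requires_grad input in eager mode causes autograd to rewrite the `grad_fn` of its outstanding aliases, which is why the mutation has to be replayed outside the compiled graph.

```python
import torch

# Hypothetical illustration: x requires grad and has an outstanding alias.
x = torch.ones(4, requires_grad=True).clone()
alias = x.view(2, 2)
print(alias.grad_fn)  # a view-style backward node

# Eager in-place mutation of the base: autograd bumps the version counter,
# and the alias's grad_fn is recomputed to account for the mutation.
x.mul_(2)
print(alias.grad_fn)  # a different node than before the mutation
```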
…rad, even if they require_grad (#114646)

Quick recap of events:
(1) #111347, which fixed a perf regression in 2.1 compared to 2.0, introduced a correctness problem around input mutations on inputs that require grad that show up in an inference-only graph (the specific case where this can happen is rare and nobody reported the issue, but it was fixed a few weeks later)
(2) That fix happened here: #113584, which makes sure to keep input mutations outside of the graph, so the autograd engine can set metadata properly on them
(3) That in turn caused a slight regression compared to (1), which is what this PR attempts to fix.

In particular, code like the below is safe to keep the mutations in the graph for:
```
@torch.compile
def f(x):
    x.mul_(2)

x = torch.ones(2, requires_grad=True).clone()
# x requires_grad, so the input mutation will change some autograd metadata, like the version counter
# However, the mutation is under no_grad, so we don't have to worry about e.g. aliases of x having their .grad_fn fields changed
with torch.no_grad():
    f(x)
```
This particular case is pretty important to the shampoo optimizer code, which is run under `torch.compile`, and mutates parameters (which require grad).

Pull Request resolved: #114646
Approved by: https://github.com/zou3519
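A quick check of the claim above (a sketch under the same assumptions, not code from the PR): under `no_grad`, an in-place mutation on a tensor that requires grad still bumps its version counter, but it does not record a new `grad_fn`, which is why keeping the mutation in the graph is safe in this case.

```python
import torch

# Illustrative check: an in-place op under no_grad still bumps the version
# counter but does not rewrite the tensor's grad_fn.
x = torch.ones(2, requires_grad=True).clone()
print(x._version, x.grad_fn)   # 0, the CloneBackward node from .clone()

with torch.no_grad():
    x.mul_(2)                  # mutation happens under no_grad

print(x._version, x.grad_fn)   # 1, still the same clone-backward node
```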
Re-land of #111011.

The original PR ended up having a bad interaction with code that tried to run `torch.compile` under `with torch.inference_mode`, which caused some internal tests to fail. The issue was that:
(1) AOTInductor invokes the pattern matcher passes in inductor.
(2) The pattern matcher registers some code with [training_graph](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/pad_mm.py#L461).
(3) The `training_graph` function expects to be able to set the global autograd state to `requires_grad`, and always get out a joint graph (assertion [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/pattern_matcher.py#L1196)).
(4) However, when inference_mode is activated and you try to run AOTAutograd, AOTAutograd will see that none of the outputs of the traced function require grad, and (now correctly) decide that it is tracing an inference graph, which fails the above assert.
After talking to Bin, it sounds like these training-only patterns aren't necessary when we know we are compiling an inference graph (which should always be the case if you're running torch.compile with inference_mode). So I updated the pattern matcher to ignore any pattern matches using `training_graph` when inference_mode is enabled.
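A minimal sketch of the scenario (hypothetical repro, not the PR's actual test): compiling under `inference_mode` guarantees that no output requires grad, so AOTAutograd now traces an inference graph and the training-only pattern matches must be skipped.

```python
import torch

def f(x, w):
    return torch.mm(x, w)

compiled_f = torch.compile(f)

x = torch.randn(8, 16)
w = torch.randn(16, 32, requires_grad=True)

with torch.inference_mode():
    out = compiled_f(x, w)
    # Even though w requires grad, inference_mode ensures no output requires
    # grad, so AOTAutograd goes down the inference path.
    assert not out.requires_grad
```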
This reverts commit cf6b1cd.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler