-
Notifications
You must be signed in to change notification settings - Fork 24.2k
Fix fake tensor caching when output has unbacked #153034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153034
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 2aca85c with merge base 480ae2d ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
We handle fake tensor caching in two ways: 1. If the inputs have no symbols (SymInt, etc) then we cache on the FakeTensorMode. 2. If the inputs have symbols then we cache on the ShapeEnv. This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call. However - it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output. In this case we shouldn't cache at all because what would that really mean? So this PR changes the caching behavior so that if there's a symbol in the output which doesn't come in some way from the input then we refuse to cache that op. Added a test which checks for this case. While in there I also did a couple other related changes: 1. Added negative caching - if we see that an (op, args) failed to cache previously we don't even bother trying to cache it again. 2. Reworked the inner behavior of _cached_dispatch_impl a little to make it more clear which bits we expect to be able to throw _BypassDispatchCache and add some comments. [ghstack-poisoned]
Sounds like this error is flaky |
Merge startedYour change will be merged while ignoring the following 2 checks: pull / linux-focal-py3_9-clang9-xla / build, inductor / unit-test / cuda12.6-py3.10-gcc9-sm86 / test (inductor_cpp_wrapper, 1, 2, ephemeral.linux.g5.4xlarge.nvidia.gpu) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
@pytorchbot revert -m "Broke pr_time_benchmarks, see https://hud.pytorch.org/hud/pytorch/pytorch/d07fbd41e3589fc9377865a95960b211ec899b90/1?per_page=50&name_filter=pr_time_be&mergeEphemeralLF=true" -c nosignal |
@pytorchbot successfully started a revert job. Check the current status here. |
This reverts commit 4f425a0. Reverted #153034 on behalf of https://github.com/malfet due to Broke pr_time_benchmarks, see https://hud.pytorch.org/hud/pytorch/pytorch/d07fbd41e3589fc9377865a95960b211ec899b90/1?per_page=50&name_filter=pr_time_be&mergeEphemeralLF=true ([comment](#153034 (comment)))
@aorenste your PR has been successfully reverted. |
We handle fake tensor caching in two ways: 1. If the inputs have no symbols (SymInt, etc) then we cache on the FakeTensorMode. 2. If the inputs have symbols then we cache on the ShapeEnv. This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call. However - it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output. In this case we shouldn't cache at all because what would that really mean? So this PR changes the caching behavior so that if there's a symbol in the output which doesn't come in some way from the input then we refuse to cache that op. Added a test which checks for this case. While in there I also did a couple other related changes: 1. Added negative caching - if we see that an (op, args) failed to cache previously we don't even bother trying to cache it again. 2. Reworked the inner behavior of _cached_dispatch_impl a little to make it more clear which bits we expect to be able to throw _BypassDispatchCache and add some comments. [ghstack-poisoned]
We handle fake tensor caching in two ways: 1. If the inputs have no symbols (SymInt, etc) then we cache on the FakeTensorMode. 2. If the inputs have symbols then we cache on the ShapeEnv. This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call. However - it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output. In this case we shouldn't cache at all because what would that really mean? So this PR changes the caching behavior so that if there's a symbol in the output which doesn't come in some way from the input then we refuse to cache that op. Added a test which checks for this case. While in there I also did a couple other related changes: 1. Added negative caching - if we see that an (op, args) failed to cache previously we don't even bother trying to cache it again. 2. Reworked the inner behavior of _cached_dispatch_impl a little to make it more clear which bits we expect to be able to throw _BypassDispatchCache and add some comments. [ghstack-poisoned]
Starting merge as part of PR stack under #152662 |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: Command
Details for Dev Infra teamRaised by workflow job |
We handle fake tensor caching in two ways: 1. If the inputs have no symbols (SymInt, etc) then we cache on the FakeTensorMode. 2. If the inputs have symbols then we cache on the ShapeEnv. This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call. However - it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output. In this case we shouldn't cache at all because what would that really mean? So this PR changes the caching behavior so that if there's a symbol in the output which doesn't come in some way from the input then we refuse to cache that op. Added a test which checks for this case. While in there I also did a couple other related changes: 1. Added negative caching - if we see that an (op, args) failed to cache previously we don't even bother trying to cache it again. 2. Reworked the inner behavior of _cached_dispatch_impl a little to make it more clear which bits we expect to be able to throw _BypassDispatchCache and add some comments. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
We handle fake tensor caching in two ways:
This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call.
However - it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output. In this case we shouldn't cache at all because what would that really mean?
So this PR changes the caching behavior so that if there's a symbol in the output which doesn't come in some way from the input then we refuse to cache that op.
Added a test which checks for this case.
While in there I also did a couple other related changes:
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames