8000 [multigraph] use specializations in compile_and_call_fx_graph by bobrenjc93 · Pull Request #153449 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[multigraph] use specializations in compile_and_call_fx_graph #153449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from

Conversation

bobrenjc93
Copy link
Contributor
@bobrenjc93 bobrenjc93 commented May 13, 2025

Stack from ghstack (oldest at bottom):

The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work wa 8000 s inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There's really two parts of this work:

The frontend changes:

  1. we introduce an optional kwarg specialize_on to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

The backend changes (this PR):

  1. We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in /_dynamo/variables/builder.py
  2. After we are done dynamo tracing, we will lazily (more on this later) invoke call_user_compiler up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
  3. When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

495269647_576053105510082_9189856138964956774_n

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Copy link
pytorch-bot bot commented May 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153449

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit be5ef29 with merge base ef1d45b (image):

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…aph"

cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
bobrenjc93 added a commit that referenced this pull request May 13, 2025
@bobrenjc93 bobrenjc93 added the topic: not user facing topic category label May 13, 2025
…aph"

cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
bobrenjc93 added a commit that referenced this pull request May 13, 2025
Copy link
Contributor
@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approach looks reasonable to me. Do we still need to pass specialization to the backend? It is nicer if we don't need to (to not complicate the backend interface), but I could buy that there's something special we need to do

…aph"

cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
bobrenjc93 added a commit that referenced this pull request May 13, 2025
@bobrenjc93 bobrenjc93 requested a review from anijain2305 May 13, 2025 23:50
@bobrenjc93 bobrenjc93 marked this pull request as ready for review May 13, 2025 23:50
…aph"


The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. 

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. NB: instead of doing all of the specialization compiled up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)


cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
…aph"


The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. 

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we will lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)


cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
@bobrenjc93
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 29, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

…aph"


The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. 

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we will lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)


cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team Raised by workflow job

@bobrenjc93
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

…aph"


The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. 

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we will lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)


cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
@bobrenjc93
Copy link
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: inductor / unit-test / linux-jammy-cpu-py3.9-gcc11-inductor / test (inductor_avx2, 1, 2, linux.10xlarge.avx2)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request May 30, 2025
AOTAutogradCache uses FXGraphCache which uses the tracing context to get the ShapeEnv. Although the TracingContext global_context is cleared by the time we get around to reusing it, we don't actually need it. We just need the ShapeEnv in the TracingContext, which isn't cleared at the end of dynamo and does persist. This PR adds the tracing context manager around the specialized compile to ensure our caching infrastructure can get access to the ShapeEnv. A test was also added to prove correctness.

Pull Request resolved: #153526
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
ghstack dependencies: #153433, #153449
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this pull request Jun 2, 2025
…h#153449)

The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we will lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with pytorch#152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

Pull Request resolved: pytorch#153449
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#153433
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this pull request Jun 2, 2025
AOTAutogradCache uses FXGraphCache which uses the tracing context to get the ShapeEnv. Although the TracingContext global_context is cleared by the time we get around to reusing it, we don't actually need it. We just need the ShapeEnv in the TracingContext, which isn't cleared at the end of dynamo and does persist. This PR adds the tracing context manager around the specialized compile to ensure our caching infrastructure can get access to the ShapeEnv. A test was also added to prove correctness.

Pull Request resolved: pytorch#153526
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
ghstack dependencies: pytorch#153433, pytorch#153449
qingyi-yan pushed a commit to qingyi-yan/pytorch that referenced this pull request Jun 3, 2025
…h#153449)

The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we will lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with pytorch#152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

Pull Request resolved: pytorch#153449
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#153433
qingyi-yan pushed a commit to qingyi-yan/pytorch that referenced this pull request Jun 3, 2025
AOTAutogradCache uses FXGraphCache which uses the tracing context to get the ShapeEnv. Although the TracingContext global_context is cleared by the time we get around to reusing it, we don't actually need it. We just need the ShapeEnv in the TracingContext, which isn't cleared at the end of dynamo and does persist. This PR adds the tracing context manager around the specialized compile to ensure our caching infrastructure can get access to the ShapeEnv. A test was also added to prove correctness.

Pull Request resolved: pytorch#153526
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
ghstack dependencies: pytorch#153433, pytorch#153449
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
…h#153449)

The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM which does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There's really two parts of this work:

**The frontend changes:**
1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we will lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with pytorch#152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

Pull Request resolved: pytorch#153449
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#153433
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
AOTAutogradCache uses FXGraphCache which uses the tracing context to get the ShapeEnv. Although the TracingContext global_context is cleared by the time we get around to reusing it, we don't actually need it. We just need the ShapeEnv in the TracingContext, which isn't cleared at the end of dynamo and does persist. This PR adds the tracing context manager around the specialized compile to ensure our caching infrastructure can get access to the ShapeEnv. A test was also added to prove correctness.

Pull Request resolved: pytorch#153526
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
ghstack dependencies: pytorch#153433, pytorch#153449
@github-actions github-actions bot deleted the gh/bobrenjc93/343/head branch July 2, 2025 02:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants
0