[multigraph] add backend_specialization kwarg to mark_dynamic #152597 · pytorch/pytorch

Closed
wants to merge 12 commits

Conversation

@bobrenjc93 (Contributor) commented May 1, 2025

Stack from ghstack (oldest at bottom):

The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: it uses a custom backend to capture a dynamo graph and then manually invokes compile_fx multiple times to get specialized graphs.

There are really two parts to this work:

The frontend changes (this PR):

  1. We introduce an optional kwarg backend_specializations to mark_dynamic that takes in a list of specializations (a usage sketch follows this list). I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.
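
A minimal usage sketch following the docstring added in this PR (note this kwarg never landed, since the PR was later closed in favor of the lazy approach):

```python
import torch

def f(x):
    return x * 2

x = torch.randn(16, 8)

# Mark dim 0 of x as dynamic, and additionally ask the backend to compile
# a specialized graph for the case where that dim equals 16.
torch._dynamo.mark_dynamic(
    x,
    0,
    backend_specializations=[(
        16,                 # hint value used to compile the specialization
        lambda s: s == 16,  # specialization predicate checked at dispatch time
    )],
)

compiled_f = torch.compile(f)
compiled_f(x)  # dispatches to the specialized region when dim 0 == 16
```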

The backend changes:

  1. We capture the backend_specializations specified in the mark_dynamic API into a SymbolicContext. See the changes in /_dynamo/variables/builder.py.
  2. After dynamo tracing is done, we invoke call_user_compiler N + 1 times: once for each of the N specializations and once for the generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache.
  3. When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches; if none match, we dispatch to the generic graph (see the sketch after this list). I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (e.g., with 8 specializations you would hit the cache limit), and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.
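
A schematic sketch of steps 2 and 3; every name below is illustrative rather than the PR's actual internals, and the predicates are shown taking the runtime inputs directly as a simplification:

```python
# Illustrative sketch of the N + 1 compilation and dispatch scheme; all
# names here are hypothetical, not this PR's actual internals.
def compile_with_specializations(gm, example_inputs, specializations, compiler):
    # Step 2: N + 1 backend compilations from the single dynamo trace,
    # one per specialization plus one generic graph.
    specialized = [
        (predicate, compiler(gm, example_inputs, hint=hint))
        for hint, predicate in specializations
    ]
    generic = compiler(gm, example_inputs)

    # Step 3: dispatch to the first specialization whose predicate matches;
    # otherwise fall back to the generic graph, whose guards are strictly
    # weaker than any specialization's.
    def dispatch(*args):
        for predicate, compiled_fn in specialized:
            if predicate(*args):
                return compiled_fn(*args)
        return generic(*args)

    return dispatch
```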

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

pytorch-bot bot commented May 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152597

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 7952fa0 with merge base 8f54e56:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames
bobrenjc93 added 3 commits May 1, 2025 23:36
This was referenced May 3, 2025
bobrenjc93 added 3 commits May 3, 2025 19:34
@bobrenjc93 bobrenjc93 added the topic: not user facing topic category label May 4, 2025
@bobrenjc93 bobrenjc93 changed the title add backend_specialization kwarg to mark_dynamic [multigraph] add backend_specialization kwarg to mark_dynamic May 4, 2025
bobrenjc93 added 3 commits May 4, 2025 19:11
…mic"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
…mic"


The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. 

There's really two parts of this work:

**The frontend changes (this PR):**
1) we introduce an optional kwarg `backend_specializations` to mark_dynamic that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes:**
1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache.
3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)


[ghstack-poisoned]
…mic"


The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. 

There's really two parts of this work:

**The frontend changes (this PR):**
1) we introduce an optional kwarg `backend_specializations` to mark_dynamic that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes:**
1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`
2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache.
3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)


[ghstack-poisoned]
@bobrenjc93 bobrenjc93 requested review from zou3519 and eellison May 5, 2025 12:46
@bobrenjc93 bobrenjc93 marked this pull request as ready for review May 5, 2025 12:46
@@ -552,7 +552,7 @@ def mark_unbacked(t, index, strict=False):
 
 
 @forbid_in_graph
-def mark_dynamic(t, index, *, min=None, max=None):
+def mark_dynamic(t, index, *, min=None, max=None, backend_specializations=None):
@zou3519 (Contributor) commented May 12, 2025:
  1. Can we do this for mark_unbacked as well? I think vLLM might actually want to use mark_unbacked.
  2. We should probably error somehow if the range on the symint ends up not including the backend_specialization (a hypothetical sketch of such a check follows).
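
Something along these lines, purely illustrative (the helper name, signature, and error message are assumptions; no such check exists in this PR):

```python
# Hypothetical range check in the spirit of suggestion 2; not part of this
# PR. min_ and max_ mirror mark_dynamic's min/max kwargs.
def _check_specializations_in_range(min_, max_, backend_specializations):
    for hint, _predicate in backend_specializations:
        if (min_ is not None and hint < min_) or (max_ is not None and hint > max_):
            raise ValueError(
                f"backend specialization hint {hint} falls outside the "
                f"declared dynamic range [{min_}, {max_}]"
            )
```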

Comment on lines +578 to +586

    5) If backend specializations is passed in, we will perform a single generic Dynamo trace followed by
    multiple specialized compilations in addition to a single generic compilation. At runtime, we will dispatch
    to a specialized compiled region if the input matches the specialization criteria. For example:

        mark_dynamic(..., backend_specializations=[(
            16,                 # hint value
            lambda x: x == 16,  # specialization predicate
        )])

@zou3519 (Contributor) commented May 12, 2025:

backend_specializations is a bit wordy. I don't have a better idea though.

mark_dynamic(..., specialize_on=[
   16,
   lambda x: x == 16,
])

?

@zou3519 (Contributor) left a comment:

API seems reasonable. My judgment might change after I read the next PR.

@@ -575,6 +575,18 @@ def mark_dynamic(t, index, *, min=None, max=None):
 4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
 before torch.compile.
 
+5) If backend specializations is passed in, we will perform a single generic Dynamo trace followed by
A contributor commented:
Do multiple different dims marked dynamic with backend specializations produce a cross product of specialized graphs?

@bobrenjc93 (Contributor, Author) commented:

Abandoning in favor of lazy approach: #153449

@bobrenjc93 bobrenjc93 closed this May 13, 2025