[multigraph] add backend_specialization kwarg to mark_dynamic #152597
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152597. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 7952fa0 with merge base 8f54e56. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: it uses a custom backend to capture a dynamo graph and then manually invokes compile_fx multiple times to get specialized graphs.

There are really two parts to this work:

**The frontend changes (this PR):**

1) We introduce an optional kwarg `backend_specializations` to mark_dynamic that takes a list of specializations. I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes:**

1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`.
2) After we are done with dynamo tracing, we invoke `call_user_compiler` N + 1 times: once for each of the N specializations and once for the generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache.
3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (e.g. with 8 specializations you would hit the cache limit), and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions.
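The dispatch scheme in backend change 3 above can be sketched in plain Python. This is illustrative only, not the actual dynamo internals: each specialization's predicate is checked in order, and the generic compiled graph serves as the fallback.

```python
# Sketch of specialized dispatch: try each specialization's predicate in
# order; fall back to the generic compiled graph when none match. All names
# here are illustrative stand-ins, not real torch._dynamo APIs.

def make_dispatcher(specializations, generic_fn):
    """specializations: list of (predicate, compiled_fn) pairs, checked in order."""
    def dispatch(x):
        for predicate, compiled_fn in specializations:
            if predicate(x):
                return compiled_fn(x)
        return generic_fn(x)
    return dispatch

# One graph "specialized" for input 16, a generic fallback for everything else.
dispatcher = make_dispatcher(
    [(lambda n: n == 16, lambda n: f"specialized-{n}")],
    lambda n: f"generic-{n}",
)
```

Because the specializations are checked before the generic graph, their stricter guards never shadow the generic region; this mirrors the lattice argument in the description.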
```diff
@@ -552,7 +552,7 @@ def mark_unbacked(t, index, strict=False):


 @forbid_in_graph
-def mark_dynamic(t, index, *, min=None, max=None):
+def mark_dynamic(t, index, *, min=None, max=None, backend_specializations=None):
```
- Can we do this for mark_unbacked as well? I think vLLM might actually want to use mark_unbacked
- We should probably error somehow if the range on the symint ends up not including the backend_specialization
```diff
+    5) If backend specializations is passed in, we will perform a single generic Dynamo trace followed by
+       multiple specialized compilations in addition to a single generic compilation. At runtime, we will
+       dispatch to a specialized compiled region if the input matches the specialization criteria. For example:
+
+       mark_dynamic(..., backend_specializations=[(
+           16,                  # hint value
+           lambda x: x == 16,   # specialization predicate
+       )])
```
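To make the `(hint, predicate)` pair in the docstring example concrete, here is a hedged sketch of how such a pair could be consumed: the hint supplies the concrete value the specialized graph is compiled against, while the predicate gates runtime dispatch. `compile_with_hint` is a made-up stand-in, not the real backend entry point (the PR routes this through compile_fx).

```python
# Hypothetical stand-in for the backend compile step; the real flow goes
# through compile_fx as described in the PR.
def compile_with_hint(hint):
    return lambda x: ("specialized", hint, x)

# The docstring's example pair: hint value 16, predicate x == 16.
specializations = [(16, lambda x: x == 16)]

# Compile one specialized callable per hint; keep the predicate for dispatch.
compiled = [(pred, compile_with_hint(hint)) for hint, pred in specializations]
```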
`backend_specializations` is a bit wordy. I don't have a better idea though.

```python
mark_dynamic(..., specialize_on=[
    16,
    lambda x: x == 16,
])
```

?
API seems reasonable. My judgment might change after I read the next PR
```diff
@@ -575,6 +575,18 @@ def mark_dynamic(t, index, *, min=None, max=None):
     4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
        before torch.compile.

+    5) If backend specializations is passed in, we will perform a single generic Dynamo trace followed by
```
Do multiple different mark'ed dynamic dims with backend specializations produce a cross product of specialized graphs?
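The question above can be made concrete with a small sketch. Assuming a cross-product strategy (an assumption — the PR description does not say which strategy applies), two marked dims carrying 2 and 1 specializations respectively would yield 2 × 1 specialized graphs plus the one generic graph:

```python
from itertools import product

# Hypothetical specializations on two different dynamic dims.
dim0_specs = [16, 32]
dim1_specs = [64]

# A cross-product strategy compiles one specialized graph per combination,
# in addition to the single generic graph.
combinations = list(product(dim0_specs, dim1_specs))
print(combinations)  # [(16, 64), (32, 64)] -> 2 specialized + 1 generic graph
```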
Abandoning in favor of lazy approach: #153449