[multigraph] use backend specializations in compile_and_call_fx_graph #152601
Conversation
…ll_fx_graph" The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. There's really two parts of this work: **The frontend changes:** 1) we introduce an optional kwarg `backend_specializations` to mark_dynamic that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc. **The backend changes (this PR):** 1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py` 2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. 3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards. I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:  [ghstack-poisoned]
…ll_fx_graph" The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. There's really two parts of this work: **The frontend changes:** 1) we introduce an optional kwarg `backend_specializations` to mark_dynamic that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc. **The backend changes (this PR):** 1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py` 2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. 3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards. I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:  [ghstack-poisoned]
…ll_fx_graph" The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. There's really two parts of this work: **The frontend changes:** 1) we introduce an optional kwarg `backend_specializations` to mark_dynamic that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc. **The backend changes (this PR):** 1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py` 2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. 3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards. I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:  [ghstack-poisoned]
…ll_fx_graph" The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. There's really two parts of this work: **The frontend changes:** 1) we introduce an optional kwarg `backend_specializations` to mark_dynamic that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc. **The backend changes (this PR):** 1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py` 2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. 3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards. I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:  [ghstack-poisoned]
…ll_fx_graph" The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. There's really two parts of this work: **The frontend changes:** 1) we introduce an optional kwarg `backend_specializations` to mark_dynamic that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc. **The backend changes (this PR):** 1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py` 2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. 3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards. I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:  [ghstack-poisoned]
raise RuntimeError(
    "Backend specializations are only supported for contiguous tensors."
)
what's going on here?
for specialization in old_fake_mode.shape_env.backend_specializations:
    source_index = sources.index(specialization.source)
    check_fn_source = inspect.getsource(specialization.check_fn).strip()
    check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
        specialization.check_fn,
        [check_fn_source],
    )

    log.debug(
        "Compiling backend specialized graph with specialization=%s",
        check_fn_source,
    )

    specialized_compiles.append(
        (
            functools.partial(
                lambda idx, args, check_fn=check_fn: check_fn(
                    args[idx]
                ),
                source_index,
            ),
            self.call_user_compiler(gm, specialization=specialization),
This calls the backend compiler with the (tensor_args,) for this graph and the specialization argument, right?
I'm not sure this is the right design. The (tensor_args,) don't have the same shape as the specialization -- will that be a problem?
An alternative design is that there is some lazy dispatching layer right after Dynamo but before AOTAutograd. Let's say the user calls the following for the first time:
# A
x = torch.randn(3)
mark_dynamic(x, 0, backend_specializations=[1, 2])
torch.compile(f)(x)
Then this traces out a graph from Dynamo with dynamic shapes.
Then, on future calls to torch.compile:
# B
y = torch.randn(1)
torch.compile(f)(y)
- On seeing a specialized shape for the first time: this skips Dynamo but directly forwards the args (y,) to the backend to compile a graph
# C
z = torch.randn(1)
torch.compile(f)(z)
- On seeing a specialized shape again: this pulls up the graph the backend compiled for said shape.
One way to implement this is:
- Let's think about the Dynamo cache as a mapping from guards to a callable
- After (A), there is a guard for each of the specializations: {"batch_size==1": call_backend_compile(), "batch_size==2": call_backend_compile(), "batch_size==anything_else": compiled_artifact}
- (B) hits the call_backend_compile() function, which will compile a backend function and replace the Dynamo cache entry with {"batch_size==1": compiled_artifact}
- Future hits to this guard (e.g. C) will just hit the compiled artifact; a rough sketch of this cache structure follows below.
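A minimal, self-contained sketch of that lazy cache idea. The names here (`LazySpecializedEntry`, `make_cache`, `dispatch`) are hypothetical, not the actual Dynamo cache structures; `backend_compile` stands in for compile_fx and `dynamo_gm` for the graph Dynamo captured once.

```python
# Hypothetical sketch only: model the Dynamo cache as (guard, entry) pairs where
# specialized entries start as lazy stubs that compile on the first hit and then
# reuse the compiled artifact on later hits.

class LazySpecializedEntry:
    def __init__(self, backend_compile, dynamo_gm):
        self.backend_compile = backend_compile  # stand-in for compile_fx
        self.dynamo_gm = dynamo_gm              # graph captured once by Dynamo
        self.compiled = None                    # filled on the first matching call

    def __call__(self, *args):
        if self.compiled is None:
            # First time this specialized shape is seen: skip Dynamo and hand
            # the real args straight to the backend.
            self.compiled = self.backend_compile(self.dynamo_gm, args)
        return self.compiled(*args)

def make_cache(dynamo_gm, backend_compile, generic_compiled):
    return [
        (lambda args: args[0].shape[0] == 1, LazySpecializedEntry(backend_compile, dynamo_gm)),
        (lambda args: args[0].shape[0] == 2, LazySpecializedEntry(backend_compile, dynamo_gm)),
        (lambda args: True, generic_compiled),  # the generic dynamic-shape artifact
    ]

def dispatch(cache, *args):
    for guard, entry in cache:
        if guard(args):
            return entry(*args)
```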
The benefit of the alternative lazy design is that the backend doesn't need to work hard to figure out how to do the specialization: it's almost like calling regular torch.compile again, except it is able to skip Dynamo.
One side effect is that we don't have to impose constraints on the strides (this PR needs to do that because it needs to figure out how to create a FakeTensor, right?)
I think this makes sense. cc @anijain2305 for thoughts as well
There are a few details that we need to think about:
- We will have multiple cache entries per code object here. For example, our cache size limit is 8, but this specialization will require us to raise the cache size limit for certain code objects.
- "Dynamo cache as a mapping from guards to a callable" is true, but there is a subtle difference: Dynamo maps guards to bytecode, and that bytecode contains the call to the compiled graph (not the FX graph, a compiled graph). So in this design we will have to figure out how to (1) stash the bytecode and (2) stash the Dynamo graph.
- Overwriting a cache entry is also questionable. Maybe we have bytecode that calls `backend_compile`, and `backend_compile` internally checks whether there is already compiled code. If yes, run the compiled code; otherwise run the AOT + Inductor compilation.
@anijain2305 thoughts on #153449?
What is the issue with the current implementation? It's not bad, and it gives a hierarchical feel, which kind of makes sense in this case.
Consider the following code:
x = torch.randn(3)
mark_dynamic(x, 0, backend_specializations=[1, 2])
torch.compile(f)(x)
x = torch.randn(1)
torch.compile(f)(x)
x = torch.randn(2)
torch.compile(f)(x)
On the first torch.compile call, we will attempt to compile all of the backend specializations, but that torch.compile call only has one set of sample inputs (of shape [3]). The problems I'm worried about are:
a) Compile time will be slow up front. On the first torch.compile it looks like we call the backend compiler three times.
b) Because there are no real tensor inputs of shape [1] and shape [2], we need to guess at those tensors and assume that they're contiguous. This doesn't seem very good.
The lazier design (#153449) solves this by (a) deferring compilation of shape [1] and shape [2] until we actually see inputs of those shapes, and (b) simply recompiling if the strides change.
Agree with @zou3519's comments.
backend_specializations=[
    (16, lambda x0: x0 == 16),
],
Do we have an API flow for when you want to specify conditions on multiple vars? E.g.
lambda x, y: x == 1 and y == 1
lambda x, y: x % 16 and y % 16
You don't necessarily want to specialize on x == 1 and y % 16, which I assume would fall out of the pairwise specializations.
Not at the moment. vLLM actually only has one symbolic variable (https://www.anyscale.com/blog/continuous-batching-llm-inference) so we don't need to worry about that for our first customer. That being said, I'm happy to bikeshed what a better multi-var API may look like during composability.
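For reference, one purely hypothetical shape such a multi-var flow could take (not part of this PR; `joint_backend_specializations` below is an invented name, just to make the cross-product concern concrete):

```python
import torch
from torch._dynamo import mark_dynamic

x = torch.randn(16, 1280)
y = torch.randn(16, 1280)
mark_dynamic(x, 0)
mark_dynamic(y, 0)

# Hypothetical: one list of joint checks over both marked sizes, so only the
# intended combinations get specialized, rather than per-dimension
# specializations whose pairwise combinations would also cover unwanted cases
# such as x == 1 together with y % 16 == 0.
joint_backend_specializations = [
    lambda sx, sy: sx == 1 and sy == 1,
    lambda sx, sy: sx % 16 == 0 and sy % 16 == 0,
]
```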
dynamic_specialized = do_bench(
    lambda: inductor_matmul(dynamic_specialized_a, b)
)
We should check the output code
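For instance, something along these lines (a sketch using the existing `run_and_get_code` helper; the `inductor_matmul` definition and the FileCheck pattern are placeholders, since the real assertion depends on what the specialized kernel should contain, and the actual test runs on GPU rather than CPU):

```python
# Sketch only: capture and inspect the Inductor output code for the benchmarked fn.
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

@torch.compile
def inductor_matmul(a, b):  # placeholder for the function used in the test
    return a @ b

a = torch.randn(16, 1280)
b = torch.randn(1280, 16)
result, codes = run_and_get_code(inductor_matmul, a, b)
# Assert on the generated wrapper code, e.g. that a call wrapper was emitted.
FileCheck().check("def call").run(codes[0])
```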
@@ -1 +1 @@
-14256e6040d9e14698a877924456cdd92bfcd01d
+8eeef7f5b5363e9f35576184659226cc082311d6
intentional?
m = 16
k = 1280
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
Hi, the function is decorated with `requires_gpu`, which means the case will run on GPUs like CUDA/XPU, but the hard-coded `cuda` here will fail on other GPUs like XPU.
- dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+ dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
m = 16
k = 1280
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
- dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+ dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
k = 1280
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
- b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
+ b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
Abandoning in favor of the lazy approach: #153449
Stack from ghstack (oldest at bottom):

The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: it uses a custom backend to capture a dynamo graph and then manually invokes compile_fx multiple times to get specialized graphs.

There are really two parts to this work:

**The frontend changes:**
1) We introduce an optional kwarg `backend_specializations` to `mark_dynamic` that takes in a list of specializations. I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See the changes in `/_dynamo/variables/builder.py`.
2) After dynamo tracing is done, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache.
3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (e.g. with 8 specializations you would hit the cache limit) and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions.
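To make point 3 concrete, here is a simplified, self-contained sketch of the specialize-then-dispatch flow. The helper names (`build_dispatcher`, `compile_for`) are hypothetical; the real logic lives in compile_and_call_fx_graph and call_user_compiler.

```python
# Simplified sketch of the eager specialize-then-dispatch flow described above.
# compile_for is a stand-in for call_user_compiler / compile_fx.
def build_dispatcher(gm, specializations, compile_for):
    # Compile every specialized graph plus the generic one up front (N + 1 calls).
    specialized = [
        (check_fn, compile_for(gm, specialization=spec))
        for spec, check_fn in specializations
    ]
    generic = compile_for(gm, specialization=None)

    def dispatch(*args):
        for check_fn, compiled_fn in specialized:
            # e.g. guard on the size of the marked dynamic dim; the real
            # implementation looks up the recorded source index instead.
            if check_fn(args[0].shape[0]):
                return compiled_fn(*args)
        return generic(*args)

    return dispatch
```

Because each specialization's guard is strictly tighter than the generic region's guards, falling through to the generic graph is always safe, which is the hierarchical lattice structure mentioned above.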