Update on "[multigraph] use backend specializations in compile_and_call_fx_graph" · pytorch/pytorch@3e521df · GitHub

Commit 3e521df
Update on "[multigraph] use backend specializations in compile_and_call_fx_graph"
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: it uses a custom backend to capture a dynamo graph and then manually invokes compile_fx multiple times to get specialized graphs.

There are really two parts to this work.

**The frontend changes:**

1) We introduce an optional kwarg `backend_specializations` to `mark_dynamic` that takes in a list of specializations (see the hypothetical usage sketch below). I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**

1) We capture the backend specializations specified via the `mark_dynamic` API into a SymbolicContext. See the changes in `/_dynamo/variables/builder.py`.

2) After dynamo tracing is done, we invoke `call_user_compiler` N + 1 times: once for each of the N specializations and once for the generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache.

3) When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches; if none match, we dispatch to the generic graph (see the sketch after the diff below). I decided to do this over returning N different GuardedCodes because 1) it doesn't pollute the dynamo cache (e.g., with 8 specializations you would hit the cache limit), and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

[ghstack-poisoned]
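As an illustration of the frontend API described above, here is a hypothetical usage sketch. The exact shape of a specialization isn't shown in this commit, so modeling each one as a predicate over the marked dimension (matching the per-specialization `check_fn` seen in the diff below) is an assumption:

```python
import torch

def fn(x):
    return x * 2

x = torch.randn(8, 16)

# Hypothetical sketch: mark dim 0 as dynamic and additionally request
# specialized backend graphs for it. Dynamo traces fn once; the backend
# is then invoked once per specialization plus once for the generic
# graph. Passing predicates (each becoming a specialization's check_fn)
# is an assumption, not confirmed by this commit.
torch._dynamo.mark_dynamic(
    x,
    0,
    backend_specializations=[
        lambda s: s == 1,      # specialized graph for size 1
        lambda s: s % 8 == 0,  # specialized graph for multiples of 8
    ],
)

compiled_fn = torch.compile(fn)
compiled_fn(x)
```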
2 parents 00abea0 + 72c7b27 commit 3e521df

File tree

1 file changed: +8 −1 lines changed

torch/_dynamo/output_graph.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1505,10 +1505,17 @@ def compile_and_call_fx_graph(self, tx, rv, root):
         sources = [a.source for a in self.graphargs]
         for specialization in old_fake_mode.shape_env.backend_specializations:
             source_index = sources.index(specialization.source)
+            check_fn_source = inspect.getsource(specialization.check_fn).strip()
             check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
                 specialization.check_fn,
-                [inspect.getsource(specialization.check_fn)],
+                [check_fn_source],
             )
+
+            log.debug(
+                "Compiling backend specialized graph with specialization=%s",
+                check_fn_source,
+            )
+
             specialized_compiles.append(
                 (
                     functools.partial(
```
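A minimal sketch of the dispatch wrapper described in point 3 of the commit message, assuming `specialized_compiles` holds `(check_fn, compiled_fn)` pairs produced by the N + 1 `call_user_compiler` invocations; the names and structure here are illustrative, not the PR's actual code:

```python
def make_specialized_dispatch(specialized_compiles, generic_fn):
    """Return a dispatcher over specialized graphs with a generic fallback."""

    def dispatch(*args, **kwargs):
        # Each specialization's guards are strictly tighter than the
        # generic graph's guards (the hierarchical lattice structure
        # mentioned above), so the first matching check_fn wins.
        for check_fn, compiled_fn in specialized_compiles:
            if check_fn(*args, **kwargs):
                return compiled_fn(*args, **kwargs)
        # No specialization matched: fall back to the generic graph.
        return generic_fn(*args, **kwargs)

    return dispatch
```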
