Update on "[multigraph] use specializations in compile_and_call_fx_graph" · pytorch/pytorch@51c25ee · GitHub

Commit 51c25ee

Update on "[multigraph] use specializations in compile_and_call_fx_graph"
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There are really two parts to this work.

**The frontend changes:**
1) We introduce an optional kwarg `specialize_on` to `mark_{dynamic,unbacked}` that takes a list of specializations. I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the `mark_{dynamic,unbacked}` API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`.
2) After we are done dynamo tracing, we lazily (more on this later) invoke `call_user_compiler` up to N + 1 times: once for each of the N specializations and once for the generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch specialization-specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do them lazily: the first time a specialization is invoked, we compile it and save the result in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes because 1) it doesn't pollute the dynamo cache (e.g. with 8 specializations you would hit the cache limit) and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards. Minimal sketches of the intended frontend usage and of this lazy dispatch are shown below.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
1 parent bf7c26a commit 51c25ee
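As a minimal sketch of the frontend API described above: the `specialize_on` kwarg on `mark_dynamic` is the one introduced by this PR stack (not in released PyTorch at the time of this commit), and the predicate-per-specialization form (one lambda per specialized size) is an assumption for illustration, consistent with the `check_fn` usage in the ShapeEnv hunk below.

```python
import torch


def fn(x):
    return x * 2


x = torch.randn(8, 64)

# Mark dim 0 as dynamic and, in addition to the generic dynamic graph,
# ask for backend specializations for the cases dim0 == 8 and dim0 == 16.
# The lambda-predicate form is an assumption for this sketch.
torch._dynamo.mark_dynamic(
    x,
    0,
    specialize_on=[
        lambda s: s == 8,
        lambda s: s == 16,
    ],
)

compiled = torch.compile(fn)
compiled(x)  # single dynamo trace; specialized graphs compile lazily on dispatch
```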
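And a minimal sketch of the lazy specialized dispatch described in backend change 3), not the actual dynamo internals: `Specialization`, `compile_specialized`, and `generic_compiled` are hypothetical stand-ins. Each specialization is checked in order, compiled on first hit, cached, and the generic graph is the fallback.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Specialization:
    """Hypothetical stand-in for one backend specialization."""
    check: Callable[[int], bool]          # e.g. lambda s: s == 8
    compiled: Optional[Callable] = None   # filled in lazily on first dispatch


def make_dispatcher(specializations, compile_specialized, generic_compiled):
    """Return a dispatch function over a list of Specializations."""
    def dispatch(size: int, *args: Any):
        for spec in specializations:
            if spec.check(size):
                if spec.compiled is None:
                    # First hit: compile this specialization now and cache it,
                    # so later calls with a matching size skip compilation.
                    spec.compiled = compile_specialized(spec)
                return spec.compiled(*args)
        # No specialization matched: fall back to the generic graph, whose
        # guards are strictly weaker than any specialization's.
        return generic_compiled(*args)
    return dispatch


# Example: two specializations plus a generic fallback.
dispatch = make_dispatcher(
    [Specialization(lambda s: s == 8), Specialization(lambda s: s == 16)],
    compile_specialized=lambda spec: (lambda *a: f"specialized{a}"),
    generic_compiled=lambda *a: f"generic{a}",
)
print(dispatch(8, "x"))   # compiles the s == 8 specialization, then runs it
print(dispatch(3, "x"))   # no match -> generic graph
```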

File tree

1 file changed: +13 −0 lines changed


torch/fx/experimental/symbolic_shapes.py

Lines changed: 13 additions & 0 deletions
@@ -3644,12 +3644,25 @@ def patch_source_specialization(
         sym = self.source_to_var[name]
         expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr
         new_axioms = dict(self.get_implications(self.simplify(expr)))
+        added_replacements = {}
+        for axiom in new_axioms:
+            if (
+                isinstance(axiom, sympy.Eq)
+                and isinstance(axiom.lhs, sympy.Symbol)
+                and isinstance(axiom.rhs, sympy.Integer)
+                and axiom.lhs not in self.replacements
+            ):
+                self.replacements[axiom.lhs] = axiom.rhs
+                added_replacements[axiom.lhs] = axiom.rhs
+
         self.axioms.update(new_axioms)
         try:
             yield
         finally:
             for k in new_axioms:
                 self.axioms.pop(k, None)
+            for k in added_replacements:
+                self.replacements.pop(k, None)

     def check_equal(self, other: ShapeEnv) -> None:
         """Compare another ShapeEnv for equivalence"""

0 commit comments
