[multigraph] add specialize_on kwarg to mark_{dynamic,unbacked} (#153… · pytorch/pytorch@172015f · GitHub


Commit 172015f

bobrenjc93 authored and pytorchmergebot committed
[multigraph] add specialize_on kwarg to mark_{dynamic,unbacked} (#153433)
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There are really two parts to this work:

**The frontend changes (this PR):**
1) We introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes:**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`.
2) After we are done dynamo tracing, we lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch specialization-specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily: the first time a specialization is invoked, we do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (e.g. if you have 8 specializations, you would hit the cache limit) and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

Pull Request resolved: #153433
Approved by: https://github.com/zou3519
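As a quick illustration of the frontend API added here, the sketch below marks a dimension as dynamic while also registering two specializations. The function `f`, the tensor shapes, and the chosen sizes are illustrative only and are not taken from the PR:

```python
import torch
import torch._dynamo as dynamo


def f(x):
    return x * torch.sin(x)


x = torch.randn(16, 32)

# Dim 0 stays dynamic in the single Dynamo trace, but the backend may also
# produce specialized compilations for sizes 8 and 16. At runtime the input
# is dispatched to the first matching specialization, falling back to the
# generic graph when nothing matches.
dynamo.mark_dynamic(
    x,
    0,
    specialize_on=[
        lambda s: s == 8,
        lambda s: s == 16,
    ],
)

compiled = torch.compile(f)
out = compiled(x)  # size 16 here, so this call can hit the s == 16 specialization
```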
1 parent 9371491 commit 172015f

File tree

1 file changed: +27 -2 lines changed


torch/_dynamo/decorators.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -524,7 +524,7 @@ class directly; instead, use :func:`mark_dynamic`.
 
 
 @forbid_in_graph
-def mark_unbacked(t, index, strict=False):
+def mark_unbacked(t, index, strict=False, specialize_on=None):
     """
     Mark a tensor as having an unbacked dim. This changes the semantics of operations,
     we will always report the size does not equal zero/one, we will turn asserts
@@ -547,8 +547,13 @@ def mark_unbacked(t, index, strict=False):
             t._dynamo_strict_unbacked_indices.add(index)
             return
 
+        if not hasattr(t, "_specialized_on"):
+            t._specialize_on = {}
+
         if not hasattr(t, "_dynamo_unbacked_indices"):
             t._dynamo_unbacked_indices = set()
+
+        t._specialize_on[index] = specialize_on if specialize_on is not None else []
         t._dynamo_unbacked_indices.add(index)
         return
@@ -558,7 +563,7 @@ def mark_unbacked(t, index, strict=False):
 
 
 @forbid_in_graph
-def mark_dynamic(t, index, *, min=None, max=None):
+def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None):
     """
     Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim.
 
@@ -581,6 +586,20 @@ def mark_dynamic(t, index, *, min=None, max=None):
     4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
        before torch.compile.
 
+    5) If specialize_on is passed in, we will perform a single generic Dynamo trace followed by
+       multiple specialized compilations in addition to a single generic compilation. NB: For now we only support
+       per dimension specialization, or in other words we do not generate a cross product of specializations.
+       At runtime, we will dispatch to a specialized compiled region if the input matches the specialization criteria.
+
+       For example:
+           mark_dynamic(..., specialize_on=[
+               lambda x: x == 8,
+               lambda x: x == 16
+           ])
+
+       This approach results in one Dynamo trace and two backend compilations. When the input dimension equals 8 or 16
+       at runtime, execution will be directed to the specialized compiled region. Performance measurements indicate
+       2-8x speedups depending on the specific specialization and model architecture.
     """
     if is_traceable_wrapper_subclass(t):
         # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
@@ -593,14 +612,20 @@ def mark_dynamic(t, index, *, min=None, max=None):
         if not hasattr(t, "_dynamo_dynamic_indices"):
             t._dynamo_dynamic_indices = set()
             t._dynamo_dynamic_range = set()
+
+        if not hasattr(t, "_specialize_on"):
+            t._specialize_on = {}
+
         # TODO(voz): Should we bounds check?
         t._dynamo_dynamic_indices.add(index)
         t._dynamo_dynamic_range.add(_DimRange(index, min, max))
+        t._specialize_on[index] = specialize_on if specialize_on is not None else []
         return
 
     assert isinstance(index, (list, tuple))
     for i in index:
         mark_dynamic(t, i, min=min, max=max)
+        mark_dynamic(t, i, min=min, max=max, specialize_on=specialize_on)
 
 
 @forbid_in_graph
```
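The lazy dispatch described in point 3 of the backend changes is not part of this diff. As a rough mental model only, with invented names and none of the real Dynamo/ShapeEnv machinery, the strategy looks something like this:

```python
from typing import Any, Callable


class LazySpecializedDispatch:
    """Toy model of the dispatch strategy: predicates are checked in order,
    each matching specialization is compiled on first use and cached, and
    inputs that match no predicate fall back to the generic compiled graph."""

    def __init__(
        self,
        generic: Callable[..., Any],
        specializations: list[tuple[Callable[[int], bool], Callable[[], Callable[..., Any]]]],
    ) -> None:
        self.generic = generic
        # Each entry pairs a predicate on the marked dimension with a thunk
        # that compiles the specialized variant when it is first needed.
        self.specializations = specializations
        self.cache: dict[int, Callable[..., Any]] = {}

    def __call__(self, x):
        size = x.shape[0]  # the dimension that was marked dynamic
        for i, (predicate, compile_variant) in enumerate(self.specializations):
            if predicate(size):
                if i not in self.cache:  # compile lazily on the first hit
                    self.cache[i] = compile_variant()
                return self.cache[i](x)
        return self.generic(x)  # no specialization matched: generic graph
```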
