Update on "[multigraph] use specializations in compile_and_call_fx_graph" · pytorch/pytorch@96c5fae · GitHub

Commit 96c5fae

Update on "[multigraph] use specializations in compile_and_call_fx_graph"
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: it uses a custom backend to capture a dynamo graph and then manually invokes compile_fx multiple times to get specialized graphs.

There are really two parts to this work.

**The frontend changes:**
1) We introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations (see the usage sketch after this message). I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`.
2) After we are done dynamo tracing, we lazily (more on this later) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch specialization-specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we do the compiles lazily: the first time a specialization is invoked, we do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes because 1) it doesn't pollute the dynamo cache (e.g. if you have 8 specializations, you would hit the cache limit), and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
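For illustration, here is a minimal usage sketch of the `specialize_on` frontend, mirroring the `test_multi_graph_specialization` test touched by this commit (the comments describe the intended dispatch behavior, not API guarantees):

```python
import torch

def fn(x):
    return x * 5

a = torch.randn(5)

# Dim 0 is traced dynamically once; the backend may additionally compile
# specialized graphs for inputs that satisfy the predicates below.
torch._dynamo.mark_dynamic(
    a,
    0,
    specialize_on=[
        lambda x: x == 8,
        lambda x: x == 16,
    ],
)

compiled_fn = torch.compile(fn, backend="inductor")

compiled_fn(a)                # generic dynamic-shape graph
compiled_fn(torch.randn(8))   # lazily compiles, then dispatches to the x == 8 specialization
compiled_fn(torch.randn(16))  # lazily compiles, then dispatches to the x == 16 specialization
```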
1 parent fa363b2 commit 96c5fae

File tree: 3 files changed (+24, -75 lines)


test/dynamo/test_aot_autograd_cache.py

Lines changed: 0 additions & 43 deletions
@@ -314,49 +314,6 @@ def fn(x, y):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
 
-    @inductor_config.patch("fx_graph_remote_cache", False)
-    @inductor_config.patch("fx_graph_cache", True)
-    @functorch_config.patch({"enable_autograd_cache": True})
-    def test_multi_graph_specialization(self):
-        """
-        Verify multi graph specializations all cache hit
-        """
-
-        def fn(x):
-            return x * 5
-
-        a = torch.randn(5)
-        a8 = torch.randn(8)
-        a16 = torch.randn(16)
-        torch._dynamo.mark_dynamic(
-            a,
-            0,
-            specialize_on=[
-                lambda x: x == 8,
-                lambda x: x == 16,
-            ],
-        )
-
-        compiled_fn = torch.compile(fn, backend="inductor")
-
-        # A first call should miss in the cache.
-        compiled_fn(a)
-        compiled_fn(a8)
-        compiled_fn(a16)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 3)
-
-        self._clear_dynamo_and_codecache()
-
-        # A second call should hit on all 3 graphs
-        compiled_fn(a)
-        compiled_fn(a8)
-        compiled_fn(a16)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 3)
-
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
     @functorch_config.patch({"enable_autograd_cache": True})

torch/_dynamo/output_graph.py

Lines changed: 23 additions & 31 deletions
@@ -33,7 +33,8 @@
 import sys
 import traceback
 import weakref
-from dataclasses import dataclass, replace
+from collections.abc import Sequence
+from dataclasses import dataclass
 from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
 
 import sympy
@@ -51,7 +52,6 @@
     CompileId,
     GlobalContextCheckpointState,
     Source,
-    tracing,
     TracingContext,
 )
 from torch._subclasses.fake_tensor import FakeTensor
@@ -1502,7 +1502,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             self.tracing_context.fake_mode = backend_fake_mode
 
         with self.restore_global_state():
-            compiled_fn = self.call_user_compiler(gm)
+            compiled_fn = self.call_user_compiler(gm, self.example_inputs())
 
         from torch.fx._lazy_graph_module import _LazyGraphModule
 
@@ -1536,11 +1536,6 @@ def compile_and_call_fx_graph(self, tx, rv, root):
         if specializations := old_fake_mode.shape_env.specializations:
             specialization_guards = []
             specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
-            preserved_graphargs = [
-                replace(node.meta["grapharg"], _example=None)
-                for node in self.placeholders
-            ]
-            preserved_tracing_context = torch._guards.TracingContext.get()
             sources = [a.source for a in self.graphargs]
             for specialization in specializations:
                 source_index = sources.index(specialization.source)
@@ -1576,23 +1571,15 @@ def specialized_dispatch(*args, **kwargs):
                                 *args, **kwargs
                             )
 
-                        for node, grapharg, arg in zip(
-                            self.placeholders, preserved_graphargs, args
+                        with self.shape_env.patch_source_specialization(
+                            specialization.source, specialization.check_fn
                         ):
-                            node.meta["grapharg"] = replace(grapharg, _example=arg)
-
-                        with tracing(preserved_tracing_context):
-                            shape_env = (
-                                preserved_tracing_context.fake_mode.shape_env
+                            # Modify gm so AOTAutogradCache key changes per specialization
+                            gm.meta["specialization"] = specialization
+                            specialization_cache[specialization] = (
+                                self.call_user_compiler(gm, args)
                             )
-                            with shape_env.patch_source_specialization(
-                                specialization.source, specialization.check_fn
-                            ):
-                                # Modify gm so AOTAutogradCache key changes per specialization
-                                gm.meta["specialization"] = specialization
-                                specialization_cache[specialization] = (
-                                    self.call_user_compiler(gm)
-                                )
+
                         return specialization_cache[specialization](*args, **kwargs)
                 return compiled_fn(*args, **kwargs)
 
@@ -1612,16 +1599,20 @@ def placeholders(self) -> list[fx.Node]:
     def graphargs(self) -> list[GraphArg]:
         return [node.meta["grapharg"] for node in self.placeholders]
 
-    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: Sequence[Any]
+    ) -> CompiledFn:
         with dynamo_timed(
             "OutputGraph.call_user_compiler",
             phase_name="backend_compile",
             log_pt2_compile_event=True,
             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
         ):
-            return self._call_user_compiler(gm)
+            return self._call_user_compiler(gm, example_inputs)
 
-    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def _call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: Sequence[Any]
+    ) -> CompiledFn:
         assert self.compiler_fn is not None
         tot = 0
         placeholders = []
@@ -1632,10 +1623,11 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
                 placeholders.append(node)
         increment_op_count(tot)
         for pl in placeholders:
-            arg = pl.meta["grapharg"]
-            # TODO: Why isn't this stored in meta :think:
-            # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
-            pl._dynamo_source = arg.source
+            if not hasattr(pl, "_dynamo_source"):
+                arg = pl.meta["grapharg"]
+                # TODO: Why isn't this stored in meta :think:
+                # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
+                pl._dynamo_source = arg.source
 
         # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
         gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
@@ -1651,7 +1643,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             compiler_fn = self.compiler_fn
             if config.verify_correctness:
                 compiler_fn = WrapperBackend(compiler_fn)
-            compiled_fn = compiler_fn(gm, self.example_inputs())
+            compiled_fn = compiler_fn(gm, example_inputs)
             _step_logger()(logging.INFO, f"done compiler function {name}")
             assert callable(compiled_fn), "compiler_fn did not return callable"
         except (TensorifyScalarRestartAnalysis, ShortenTraceback):
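As a reading aid for the `specialized_dispatch` hunk above, here is a simplified, standalone sketch of the same check-then-compile-and-cache pattern. `compile_for` and `generic_fn` are hypothetical stand-ins for `self.call_user_compiler(gm, args)` and the generic `compiled_fn`; the real code also patches specialization axioms into the ShapeEnv before compiling:

```python
from typing import Any, Callable

def make_specialized_dispatch(
    specializations: list[Any],
    compile_for: Callable[..., Callable],
    generic_fn: Callable,
) -> Callable:
    # Compile each specialization only on its first use and memoize the result.
    cache: dict[Any, Callable] = {}

    def dispatch(*args, **kwargs):
        for spec in specializations:
            if spec.check_fn(args):
                if spec not in cache:
                    # First hit for this specialization: compile and cache it.
                    cache[spec] = compile_for(spec, args)
                return cache[spec](*args, **kwargs)
        # No specialization matched: run the generic (dynamic-shape) graph.
        return generic_fn(*args, **kwargs)

    return dispatch
```

The first call that matches a given predicate pays the compile cost; later matching calls hit the cache, and inputs that match nothing fall through to the generic graph, which is the lazy behavior described in the commit message.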

torch/fx/experimental/symbolic_shapes.py

Lines changed: 1 addition & 1 deletion
@@ -3634,7 +3634,7 @@ def patch_source_specialization(
         """
         name = source.name()
         sym = self.source_to_var[name]
-        expr = check_fn(sym)
+        expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr
         new_axioms = dict(self.get_implications(self.simplify(expr)))
         self.axioms.update(new_axioms)
         try:
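One plausible reading of the one-line `symbolic_shapes.py` change above (the SymInt/SymNode wrapping is exactly what the diff shows; the motivation below is my assumption): predicates passed via `specialize_on`, such as `lambda x: x == 8`, rely on ordinary Python comparisons, which only produce a symbolic expression when the argument behaves like a `SymInt`:

```python
import sympy

s0 = sympy.Symbol("s0", integer=True)
check = lambda x: x == 8

# On a raw sympy Symbol, `==` is structural equality and simply returns False,
# so calling check_fn directly on `sym` never yields a relation that
# get_implications()/simplify() can consume.
print(check(s0))        # False
print(sympy.Eq(s0, 8))  # Eq(s0, 8) -- the kind of relation the ShapeEnv needs

# Wrapping `sym` as SymInt(SymNode(sym, shape_env, int, None)) routes `x == 8`
# through SymInt's comparison machinery, which returns a symbolic bool;
# `.node._expr` then recovers the underlying sympy expression used to patch
# the specialization axioms above.
```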
