[multigraph] use specializations in compile_and_call_fx_graph (#153449) · ROCm/pytorch@d08243e · GitHub

Commit d08243e

bobrenjc93 authored and iupaikov-amd committed
[multigraph] use specializations in compile_and_call_fx_graph (pytorch#153449)
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs.

There are really two parts to this work.

**The frontend changes:**
1) We introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods, including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.

**The backend changes (this PR):**
1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py`.
2) After we are done dynamo tracing, we will lazily (more on this below) invoke `call_user_compiler` up to N + 1 times for N specializations and 1 generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch specialization-specific axioms into the ShapeEnv before invoking the user compiler.
3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. Instead of doing all of the specialization compiles up front, we compile lazily: the first time a specialization is invoked, we do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph.

I decided to do this over returning N different GuardedCodes since:
1) it doesn't pollute the dynamo cache (e.g. with 8 specializations you would hit the cache limit), and
2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.

I benchmarked this PR stack with pytorch#152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

Pull Request resolved: pytorch#153449
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#153433
1 parent 336a644 commit d08243e
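
To make the frontend API described in the commit message concrete, here is a minimal usage sketch distilled from the new test in this commit (`test_inductor_multiple_specializations`). The shapes and the `x0 == 16` predicate are illustrative only, and a CUDA/ROCm device is assumed, as in the test:

```python
import torch

# Compiled region: a single Dynamo trace, but the backend may additionally
# produce a graph specialized to the predicate registered below.
@torch.compile(dynamic=False)
def matmul(a, b):
    return torch.mm(a, b)

# Assumes a GPU is available, mirroring the test setup.
a = torch.randn(16, 1280, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1280, 16, device="cuda", dtype=torch.bfloat16)

# Dim 0 of `a` is traced dynamically; `specialize_on` additionally asks the
# backend for a graph specialized to x0 == 16. At runtime the dispatcher runs
# the specialized graph when the check passes, else the generic dynamic graph.
torch._dynamo.decorators.mark_dynamic(a, 0, specialize_on=[lambda x0: x0 == 16])
torch._dynamo.decorators.mark_dynamic(b, 1)

out = matmul(a, b)
```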

File tree

6 files changed: +218 -17 lines

docs/source/fx.experimental.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ torch.fx.experimental.symbolic_shapes
     PropagateUnbackedSymInts
     DivideByKey
     InnerTensorKey
+    Specialization

     hint_int
     is_concrete_int
```

test/inductor/test_torchinductor.py

Lines changed: 44 additions & 0 deletions
```diff
@@ -10505,6 +10505,50 @@ def f(x):
         self.assertEqual(out_ref.stride(), out_test.stride())
         self.assertEqual(x_ref, x_test)

+    @requires_gpu()
+    @skip_if_not_triton
+    @unittest.skipIf(
+        not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
+    )
+    def test_inductor_multiple_specializations(self):
+        from triton.testing import do_bench
+
+        @torch.compile(
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+            dynamic=False,
+        )
+        def inductor_matmul(a, b):
+            torch._check(a.shape[0] == b.shape[1])
+            return (m, torch.mm(a, b))
+
+        m = 16
+        k = 1280
+        dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
+        dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
+        b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
+        torch._dynamo.decorators.mark_dynamic(
+            dynamic_a,
+            0,
+        )
+        torch._dynamo.decorators.mark_dynamic(
+            dynamic_specialized_a,
+            0,
+            specialize_on=[lambda x0: x0 == 16],
+        )
+        torch._dynamo.decorators.mark_dynamic(
+            b,
+            1,
+        )
+        dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
+        torch._dynamo.reset()
+        dynamic_specialized = do_bench(
+            lambda: inductor_matmul(dynamic_specialized_a, b)
+        )
+        self.assertGreaterEqual(dynamic, dynamic_specialized)
+
     @requires_gpu()
     def test_stride_preservation_with_stride_modifying_fx_pass(self):
         def f(x):
```
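
For readers unfamiliar with the benchmark helper: `triton.testing.do_bench` returns a mean runtime in milliseconds by default, so the final `assertGreaterEqual(dynamic, dynamic_specialized)` asserts that dispatching through the `x0 == 16` specialization is no slower than the generic dynamic-shape graph, while `torch._dynamo.reset()` clears the compilation caches between the two measurements. A minimal sketch of that timing helper (illustrative only, assumes a Triton-capable GPU):

```python
from triton.testing import do_bench

def bench_ms(fn, *args, **kwargs) -> float:
    # do_bench repeatedly invokes the closure and returns a mean time in ms.
    return do_bench(lambda: fn(*args, **kwargs))
```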

torch/_dynamo/output_graph.py

Lines changed: 75 additions & 12 deletions
```diff
@@ -43,7 +43,8 @@
 import torch.distributed as dist
 import torch.nn
 import torch.utils._pytree as pytree
-from torch import fx
+from torch import fx, Tensor
+from torch._C._dynamo import guards
 from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
 from torch._guards import (
     CompileContext,
@@ -61,6 +62,7 @@
     guard_scalar,
     is_symbolic,
     ShapeEnv,
+    Specialization,
 )
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -158,6 +160,8 @@
 graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

+RootGuardManager = guards.RootGuardManager
+

 @dataclass(frozen=True)
 class VariableTrackerCacheKey:
@@ -1675,7 +1679,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             self.tracing_context.fake_mode = backend_fake_mode

         with self.restore_global_state():
-            compiled_fn = self.call_user_compiler(gm)
+            compiled_fn = self.call_user_compiler(gm, self.example_inputs())

         from torch.fx._lazy_graph_module import _LazyGraphModule

@@ -1705,8 +1709,62 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             )

         counters["stats"]["unique_graphs"] += 1
-        # This is safe because we pre-process name to be unique
-        self.install_global_unsafe(name, compiled_fn)
+        if specializations := old_fake_mode.shape_env.specializations:
+            specialization_guards = []
+            specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
+            sources = [a.source for a in self.graphargs]
+            for specialization in specializations:
+                source_index = sources.index(specialization.source)
+                check_fn_source = inspect.getsource(specialization.check_fn).strip()
+                check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
+                    specialization.check_fn,
+                    [check_fn_source],
+                )
+
+                log.debug(
+                    "Compiling backend specialized graph with specialization=%s",
+                    check_fn_source,
+                )
+
+                specialization_guards.append(
+                    (
+                        functools.partial(
+                            lambda idx, args, check_fn=check_fn: check_fn(
+                                args[idx]
+                            ),
+                            source_index,
+                        ),
+                        specialization,
+                    )
+                )
+
+            @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
+            def specialized_dispatch(*args, **kwargs):
+                for check_fn, specialization in specialization_guards:
+                    if check_fn(args):
+                        if specialization in specialization_cache:
+                            return specialization_cache[specialization](
+                                *args, **kwargs
+                            )
+
+                        with self.shape_env.patch_source_specialization(
+                            specialization.source, specialization.check_fn
+                        ):
+                            # Modify gm so AOTAutogradCache key changes per specialization
+                            gm.meta["specialization"] = specialization
+                            example_inputs: list[Tensor] = list(args)
+                            specialization_cache[specialization] = (
+                                self.call_user_compiler(gm, example_inputs)
+                            )
+
+                        return specialization_cache[specialization](*args, **kwargs)
+                return compiled_fn(*args, **kwargs)
+
+            # This is safe because we pre-process name to be unique
+            self.install_global_unsafe(name, specialized_dispatch)
+        else:
+            # This is safe because we pre-process name to be unique
+            self.install_global_unsafe(name, compiled_fn)

         assert self.root_tx is not None
         cg = PyCodegen(self.root_tx)
@@ -1721,7 +1779,9 @@ def placeholders(self) -> list[fx.Node]:
     def graphargs(self) -> list[GraphArg]:
         return [node.meta["grapharg"] for node in self.placeholders]

-    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: list[Tensor]
+    ) -> CompiledFn:
         with dynamo_timed(
             "OutputGraph.call_user_compiler",
             phase_name="backend_compile",
@@ -1730,9 +1790,11 @@ def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
            waitcounter_name_override="compile_aot_autograd",
            dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
         ):
-            return self._call_user_compiler(gm)
+            return self._call_user_compiler(gm, example_inputs)

-    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def _call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: list[Tensor]
+    ) -> CompiledFn:
         assert self.compiler_fn is not None
         tot = 0
         placeholders = []
@@ -1743,10 +1805,11 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
                 placeholders.append(node)
         increment_op_count(tot)
         for pl in placeholders:
-            arg = pl.meta["grapharg"]
-            # TODO: Why isn't this stored in meta :think:
-            # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
-            pl._dynamo_source = arg.source
+            if not hasattr(pl, "_dynamo_source"):
+                arg = pl.meta["grapharg"]
+                # TODO: Why isn't this stored in meta :think:
+                # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
+                pl._dynamo_source = arg.source

         # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
         gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
@@ -1762,7 +1825,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             compiler_fn = self.compiler_fn
             if config.verify_correctness:
                 compiler_fn = WrapperBackend(compiler_fn)
-            compiled_fn = compiler_fn(gm, self.example_inputs())
+            compiled_fn = compiler_fn(gm, example_inputs)
             _step_logger()(logging.INFO, f"done compiler function {name}")
             assert callable(compiled_fn), "compiler_fn did not return callable"
         except (TensorifyScalarRestartAnalysis, ShortenTraceback):
```
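
The core of the change above is the lazily compiled dispatcher. As a reading aid, here is a self-contained sketch of that dispatch-and-cache pattern with illustrative names only; the real code builds its check functions with `guards.LAMBDA_GUARD`, patches the ShapeEnv via `patch_source_specialization`, and compiles through `call_user_compiler`, none of which appear here:

```python
from typing import Any, Callable

# Illustrative stand-in: a specialization is identified by a hashable key,
# paired with a predicate over the runtime arguments.
SpecKey = str

def make_dispatcher(
    specializations: list[tuple[Callable[[tuple], bool], SpecKey]],
    compile_specialized: Callable[[SpecKey, tuple], Callable[..., Any]],
    generic_fn: Callable[..., Any],
) -> Callable[..., Any]:
    cache: dict[SpecKey, Callable[..., Any]] = {}

    def dispatch(*args: Any, **kwargs: Any) -> Any:
        for check_fn, key in specializations:
            if check_fn(args):
                # First hit: compile lazily and cache; later hits reuse it.
                if key not in cache:
                    cache[key] = compile_specialized(key, args)
                return cache[key](*args, **kwargs)
        # No specialization matched: fall back to the generic compiled graph.
        return generic_fn(*args, **kwargs)

    return dispatch

# Toy usage: specialize when the first argument equals 16.
dispatch = make_dispatcher(
    specializations=[(lambda args: args[0] == 16, "m=16")],
    compile_specialized=lambda key, args: (lambda *a, **kw: f"specialized {key}"),
    generic_fn=lambda *a, **kw: "generic",
)
assert dispatch(16) == "specialized m=16"
assert dispatch(32) == "generic"
```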

torch/_dynamo/variables/builder.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -3134,6 +3134,7 @@ def update_dim2constraint(dim, constraint_range, name):
     dynamic_strides = []
     constraint_sizes = []
     constraint_strides = []
+    specialize_on = []
     for i in range(e.dim()):
         # NB: mark dynamic has precedence over static
         marked_strict_unbacked = i in getattr(
@@ -3144,6 +3145,8 @@ def update_dim2constraint(dim, constraint_range, name):
         marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
         marked_static = i in getattr(e, "_dynamo_static_indices", set())

+        specialize_on.append(getattr(e, "_specialize_on", {}).get(i, []))
+
         # Reflect the user directive in the frame_state
         # For dynamic, apply None always

@@ -3271,6 +3274,7 @@ def update_dim2constraint(dim, constraint_range, name):
         dynamic_strides=dynamic_strides,
         constraint_sizes=constraint_sizes,
         constraint_strides=constraint_strides,
+        specialize_on=specialize_on,
         view_base_context=view_base_context,
         tensor_source=source,
         shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
```
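
The builder loop above reads a `_specialize_on` attribute off the tensor; judging by the `getattr(e, "_specialize_on", {}).get(i, [])` call, this is a dict mapping a dimension index to the list of predicates passed to `mark_dynamic` (the attribute is written by the frontend half of this stack, pytorch#153433, so treating it as such is an assumption here). A rough sketch of the collection step:

```python
import torch

t = torch.randn(16, 1280)
# Assumed layout written by mark_dynamic(t, 0, specialize_on=[...]):
# a dict from dim index to the user-supplied predicate list.
t._specialize_on = {0: [lambda x0: x0 == 16]}

# Mirror of the per-dimension collection in builder.py: one (possibly empty)
# list of predicates per dimension, later handed to the SymbolicContext.
specialize_on = [getattr(t, "_specialize_on", {}).get(i, []) for i in range(t.dim())]
print([len(preds) for preds in specialize_on])  # [1, 0]
```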

torch/_export/non_strict_utils.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -143,9 +143,11 @@ def fakify(
             constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False)  # type: ignore[call-overload]
         else:
             dynamic_sizes.append(DimDynamic.STATIC)
-    symbolic_context = StatelessSymbolicContext(
-        dynamic_sizes=dynamic_sizes,
-        constraint_sizes=constraint_sizes,  # type: ignore[arg-type]
+    symbolic_context: StatelessSymbolicContext = (  # make mypy happy
+        StatelessSymbolicContext(
+            dynamic_sizes=dynamic_sizes,
+            constraint_sizes=constraint_sizes,  # type: ignore[arg-type]
+        )
     )
     t_id = id(t)
     assert mode.shape_env is not None
```
