[multigraph] use specializations in compile_and_call_fx_graph · pytorch/pytorch@f4dd47a · GitHub


Commit f4dd47a

[multigraph] use specializations in compile_and_call_fx_graph

ghstack-source-id: c2445e9
Pull Request resolved: #153449

1 parent 535fc62 commit f4dd47a
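For context, the user-facing hook exercised by the test added below is the new `specialize_on` argument to `torch._dynamo.decorators.mark_dynamic`: the marked dimension stays dynamic, but the backend is also asked to compile specialized variants guarded by the given predicates. A minimal usage sketch (the shapes and the CUDA device are illustrative assumptions, not part of this commit):

import torch

@torch.compile(dynamic=False)
def matmul(a, b):
    return torch.mm(a, b)

a = torch.randn(16, 1280, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1280, 16, device="cuda", dtype=torch.bfloat16)

# Keep dim 0 of `a` dynamic, but also request a specialized compile for
# the case a.shape[0] == 16; other sizes fall back to the dynamic graph.
torch._dynamo.decorators.mark_dynamic(a, 0, specialize_on=[lambda s: s == 16])

out = matmul(a, b)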

File tree

12 files changed: +214 -32 lines changed

test/inductor/test_torchinductor.py

Lines changed: 44 additions & 0 deletions
@@ -10464,6 +10464,50 @@ def f(x):
             self.assertEqual(out_ref.stride(), out_test.stride())
             self.assertEqual(x_ref, x_test)
 
+    @requires_gpu()
+    @skip_if_not_triton
+    @unittest.skipIf(
+        not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
+    )
+    def test_inductor_multiple_specializations(self):
+        from triton.testing import do_bench
+
+        @torch.compile(
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+            dynamic=False,
+        )
+        def inductor_matmul(a, b):
+            torch._check(a.shape[0] == b.shape[1])
+            return (m, torch.mm(a, b))
+
+        m = 16
+        k = 1280
+        dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
+        dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
+        b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
+        torch._dynamo.decorators.mark_dynamic(
+            dynamic_a,
+            0,
+        )
+        torch._dynamo.decorators.mark_dynamic(
+            dynamic_specialized_a,
+            0,
+            specialize_on=[lambda x0: x0 == 16],
+        )
+        torch._dynamo.decorators.mark_dynamic(
+            b,
+            1,
+        )
+        dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
+        torch._dynamo.reset()
+        dynamic_specialized = do_bench(
+            lambda: inductor_matmul(dynamic_specialized_a, b)
+        )
+        self.assertGreaterEqual(dynamic, dynamic_specialized)
+
     @requires_gpu()
     def test_stride_preservation_with_stride_modifying_fx_pass(self):
         def f(x):

torch/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2359,10 +2359,10 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
             )
             self.config[attr_name] = val
 
-    def __call__(self, model_, inputs_):
+    def __call__(self, model_, inputs_, **kwargs):
         from torch._inductor.compile_fx import compile_fx
 
-        return compile_fx(model_, inputs_, config_patches=self.config)
+        return compile_fx(model_, inputs_, config_patches=self.config, **kwargs)
 
     def get_compiler_config(self):
         from torch._inductor.compile_fx import get_patched_config_dict

torch/_dynamo/output_graph.py

Lines changed: 65 additions & 6 deletions
@@ -33,7 +33,7 @@
 import sys
 import traceback
 import weakref
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
 
 import sympy
@@ -44,6 +44,7 @@
 import torch.nn
 import torch.utils._pytree as pytree
 from torch import fx
+from torch._C._dynamo import guards
 from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
 from torch._guards import (
     CompileContext,
@@ -61,6 +62,7 @@
     guard_scalar,
     is_symbolic,
     ShapeEnv,
+    Specialization,
 )
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -157,6 +159,8 @@
 graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
 
+RootGuardManager = guards.RootGuardManager
+
 
 @dataclass(frozen=True)
 class VariableTrackerCacheKey:
@@ -1528,7 +1532,62 @@ def compile_and_call_fx_graph(self, tx, rv, root):
 
         counters["stats"]["unique_graphs"] += 1
         # This is safe because we pre-process name to be unique
-        self.install_global_unsafe(name, compiled_fn)
+        if specializations := old_fake_mode.shape_env.specializations:
+            specialization_guards = []
+            specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
+            preserved_graphargs = [
+                replace(node.meta["grapharg"], _example=None)
+                for node in self.placeholders
+            ]
+            sources = [a.source for a in self.graphargs]
+            for specialization in specializations:
+                source_index = sources.index(specialization.source)
+                check_fn_source = inspect.getsource(specialization.check_fn).strip()
+                check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
+                    specialization.check_fn,
+                    [check_fn_source],
+                )
+
+                log.debug(
+                    "Compiling backend specialized graph with specialization=%s",
+                    check_fn_source,
+                )
+
+                specialization_guards.append(
+                    (
+                        functools.partial(
+                            lambda idx, args, check_fn=check_fn: check_fn(
+                                args[idx]
+                            ),
+                            source_index,
+                        ),
+                        specialization,
+                    )
+                )
+
+            @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
+            def specialized_dispatch(*args, **kwargs):
+                for check_fn, specialization in specialization_guards:
+                    if check_fn(args):
+                        if specialization in specialization_cache:
+                            return specialization_cache[specialization](
+                                *args, **kwargs
+                            )
+                        for node, grapharg, arg in zip(
+                            self.placeholders, preserved_graphargs, args
+                        ):
+                            node.meta["grapharg"] = replace(grapharg, _example=arg)
+                        specialization_cache[specialization] = (
+                            self.call_user_compiler(
+                                gm, specialization=specialization
+                            )
+                        )
+                        return specialization_cache[specialization](*args, **kwargs)
+                return compiled_fn(*args, **kwargs)
+
+            self.install_global_unsafe(name, specialized_dispatch)
+        else:
+            self.install_global_unsafe(name, compiled_fn)
 
         cg = PyCodegen(tx)
         cg.make_call_generated_code(name)
@@ -1542,16 +1601,16 @@ def placeholders(self) -> list[fx.Node]:
     def graphargs(self) -> list[GraphArg]:
         return [node.meta["grapharg"] for node in self.placeholders]
 
-    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
         with dynamo_timed(
             "OutputGraph.call_user_compiler",
             phase_name="backend_compile",
             log_pt2_compile_event=True,
             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
         ):
-            return self._call_user_compiler(gm)
+            return self._call_user_compiler(gm, **kwargs)
 
-    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def _call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
         assert self.compiler_fn is not None
         tot = 0
         placeholders = []
@@ -1581,7 +1640,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             compiler_fn = self.compiler_fn
             if config.verify_correctness:
                 compiler_fn = WrapperBackend(compiler_fn)
-            compiled_fn = compiler_fn(gm, self.example_inputs())
+            compiled_fn = compiler_fn(gm, self.example_inputs(), **kwargs)
             _step_logger()(logging.INFO, f"done compiler function {name}")
             assert callable(compiled_fn), "compiler_fn did not return callable"
         except (TensorifyScalarRestartAnalysis, ShortenTraceback):
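The `specialized_dispatch` wrapper installed above is a guard-then-dispatch pattern: walk a list of (check_fn, specialization) pairs, lazily compile and cache a variant the first time its guard matches, and otherwise fall back to the generic compiled function. A standalone sketch of the same pattern, with a hypothetical `compile_variant` standing in for `call_user_compiler`:

from typing import Any, Callable

def make_dispatcher(
    guards: list[tuple[Callable[[tuple], bool], Any]],
    compile_variant: Callable[[Any], Callable[..., Any]],  # stand-in for call_user_compiler
    generic_fn: Callable[..., Any],
) -> Callable[..., Any]:
    cache: dict[Any, Callable[..., Any]] = {}

    def dispatch(*args, **kwargs):
        for check_fn, key in guards:
            if check_fn(args):
                # Compile the specialized variant lazily, once per key.
                if key not in cache:
                    cache[key] = compile_variant(key)
                return cache[key](*args, **kwargs)
        # No specialization matched: fall back to the generic compiled graph.
        return generic_fn(*args, **kwargs)

    return dispatch

# Example: specialize when the first argument has 16 rows.
dispatch = make_dispatcher(
    guards=[(lambda args: args[0].shape[0] == 16, "m=16")],
    compile_variant=lambda key: (lambda a, b: a @ b),  # toy "compiled" variant
    generic_fn=lambda a, b: a @ b,
)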

torch/_dynamo/repro/after_dynamo.py

Lines changed: 5 additions & 3 deletions
@@ -110,7 +110,7 @@ def add_paths(exc):
         # Check for either accuracy (level 4) or other type of failures.
         if config.repro_level == 4:
             # Check Accuracy
-            compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
+            compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs, **kwargs)
             if _accuracy_fails(gm, example_inputs, compiler_fn):
                 log.warning(
                     "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error."
@@ -125,7 +125,9 @@ def add_paths(exc):
                 raise exc
         else:
             try:
-                compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
+                compiled_gm = compiler_fn(
+                    copy.deepcopy(gm), example_inputs, **kwargs
+                )
                 run_fwd_maybe_bwd(compiled_gm, example_inputs)
             except Exception as exc:
                 log.warning(
@@ -147,7 +149,7 @@ def add_paths(exc):
                 add_paths(exc)
                 raise
     else:
-        compiled_gm = compiler_fn(gm, example_inputs)
+        compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
 
     return compiled_gm
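With `_call_user_compiler` and the repro wrapper above now forwarding `**kwargs`, a backend callable may receive extra keyword arguments such as `specialization`. A hedged sketch of a custom backend that simply tolerates the extra keywords (which keywords actually arrive, if any, is an implementation detail of this change):

import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs, **kwargs):
    # `kwargs` may carry backend-specific extras (e.g. `specialization`);
    # a simple backend can ignore them and run the captured graph eagerly.
    if kwargs:
        print(f"backend received extra kwargs: {sorted(kwargs)}")
    return gm.forward

@torch.compile(backend=my_backend)
def f(x):
    return x.sin() + 1

f(torch.randn(4))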

torch/_dynamo/variables/builder.py

Lines changed: 4 additions & 0 deletions
@@ -3053,6 +3053,7 @@ def update_dim2constraint(dim, constraint_range, name):
         dynamic_strides = []
         constraint_sizes = []
         constraint_strides = []
+        specialize_on = []
         for i in range(e.dim()):
             # NB: mark dynamic has precedence over static
             marked_strict_unbacked = i in getattr(
@@ -3063,6 +3064,8 @@ def update_dim2constraint(dim, constraint_range, name):
             marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
             marked_static = i in getattr(e, "_dynamo_static_indices", set())
 
+            specialize_on.append(getattr(e, "_specialize_on", {}).get(i, []))
+
             # Reflect the user directive in the frame_state
             # For dynamic, apply None always
 
@@ -3182,6 +3185,7 @@ def update_dim2constraint(dim, constraint_range, name):
             dynamic_strides=dynamic_strides,
             constraint_sizes=constraint_sizes,
             constraint_strides=constraint_strides,
+            specialize_on=specialize_on,
             view_base_context=view_base_context,
             tensor_source=source,
             shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,

torch/_export/non_strict_utils.py

Lines changed: 5 additions & 3 deletions
@@ -143,9 +143,11 @@ def fakify(
             constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False)  # type: ignore[call-overload]
         else:
             dynamic_sizes.append(DimDynamic.STATIC)
-    symbolic_context = StatelessSymbolicContext(
-        dynamic_sizes=dynamic_sizes,
-        constraint_sizes=constraint_sizes,  # type: ignore[arg-type]
+    symbolic_context: StatelessSymbolicContext = (  # make mypy happy
+        StatelessSymbolicContext(
+            dynamic_sizes=dynamic_sizes,
+            constraint_sizes=constraint_sizes,  # type: ignore[arg-type]
+        )
     )
     t_id = id(t)
     assert mode.shape_env is not None

torch/_functorch/aot_autograd.py

Lines changed: 4 additions & 1 deletion
@@ -31,7 +31,7 @@
     _pytree_subclasses_that_lose_info,
     make_fx,
 )
-from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, Specialization
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
 
@@ -489,6 +489,7 @@ def process_inputs(
     fake_mode: FakeTensorMode,
     shape_env: Optional[ShapeEnv],
     ignore_shape_env: bool = False,
+    specialization: Optional[Specialization] = None,
 ) -> FakifiedFlatArgs:
     with fake_mode:
 
@@ -547,6 +548,7 @@ def convert(idx, x):
             symbolic_context=symbolic_context,
             source=source,
             trace=trace,
+            specialization=specialization,
         )
     return result
 
@@ -1084,6 +1086,7 @@ def aot_module_simplified(
     cudagraphs: Optional[BoxedBool] = None,
     boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    ignore_shape_env: bool = False,
+    specialization: Optional[Specialization] = None,
 ) -> nn.Module:
     """
     This is the simplified or low overhead version of aot_module. For frontends

torch/_inductor/compile_fx.py

Lines changed: 6 additions & 0 deletions
@@ -127,6 +127,7 @@
 
     from torch._inductor.output_code import _StrideExprStr
     from torch._ops import OpOverload
+    from torch.fx.experimental.symbolic_shapes import Specialization
 
     from .ir import ExternKernelNode
 
@@ -1914,6 +1915,7 @@ def compile_fx(
     config_patches: Optional[dict[str, Any]] = None,
    decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
    ignore_shape_env: bool = False,
+    specialization: Optional[Specialization] = None,
 ) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
     """
     Main entry point for compiling given FX graph. Despite the fact that this
@@ -1939,6 +1941,7 @@ def compile_fx(
             inner_compile=config.patch(config_patches)(inner_compile),
             decompositions=decompositions,
             ignore_shape_env=ignore_shape_env,
+            specialization=specialization,
         )
 
     # TODO: This probably shouldn't be a recursive call
@@ -1995,13 +1998,15 @@ def compile_fx(
             inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
             decompositions=decompositions,
             ignore_shape_env=ignore_shape_env,
+            specialization=specialization,
         )
 
     recursive_compile_fx = functools.partial(
         compile_fx,
         inner_compile=inner_compile,
        decompositions=decompositions,
        ignore_shape_env=ignore_shape_env,
+        specialization=specialization,
     )
 
     if not graph_returns_tuple(model_):
@@ -2332,6 +2337,7 @@ def bw_compiler(
             cudagraphs=cudagraphs,
             boxed_forward_device_index=forward_device,
             ignore_shape_env=ignore_shape_env,
+            specialization=specialization,
         )(model_, example_inputs_)
     except ShortenTraceback as e:
         # We will also shorten the traceback inside dynamo.

torch/_subclasses/fake_tensor.py

Lines changed: 9 additions & 1 deletion
@@ -59,7 +59,11 @@
 
     from torch._guards import Source
     from torch._ops import OpOverload
-    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
+    from torch.fx.experimental.symbolic_shapes import (
+        ShapeEnv,
+        Specialization,
+        SymbolicContext,
+    )
 
 log = logging.getLogger(__name__)
 
@@ -354,6 +358,7 @@ def from_real_tensor(
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         trace: bool = True,
+        specialization: Optional[Specialization] = None,
     ) -> FakeTensor:
         # see note [Tensor Fakification and Symbol Caching]
         if not symbolic_context and not source and shape_env:
@@ -408,6 +413,7 @@ def mk_fake_tensor(
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,
+            specialization=specialization,
         )
         if out is NotImplemented:
             raise UnsupportedFakeTensorException("meta converter nyi")
@@ -2864,6 +2870,7 @@ def from_tensor(
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         trace: bool = True,
+        specialization: Optional[Specialization] = None,
     ) -> FakeTensor:
         shape_env: Optional[ShapeEnv] = self.shape_env
         if static_shapes is None:
@@ -2880,6 +2887,7 @@ def from_tensor(
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,
+            specialization=specialization,
         )
0 commit comments