8000 [multigraph] use specializations in compile_and_call_fx_graph by bobrenjc93 · Pull Request #153449 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[multigraph] use specializations in compile_and_call_fx_graph #153449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
b91ff71
[multigraph] use specializations in compile_and_call_fx_graph
bobrenjc93 May 13, 2025
30d47f1
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 13, 2025
14e69f8
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 13, 2025
88aecdc
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 13, 2025
3fcd84b
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 13, 2025
214277f
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 14, 2025
bf1809a
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 14, 2025
fa363b2
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 14, 2025
96c5fae
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 14, 2025
bf7c26a
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 14, 2025
51c25ee
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 14, 2025
0dd757f
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 14, 2025
d83b329
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 29, 2025
b0cef62
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 29, 2025
312d49d
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 29, 2025
d0eba02
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 29, 2025
8082547
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 29, 2025
be5ef29
Update on "[multigraph] use specializations in compile_and_call_fx_gr…
bobrenjc93 May 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/fx.experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ torch.fx.experimental.symbolic_shapes
PropagateUnbackedSymInts
DivideByKey
InnerTensorKey
Specialization

hint_int
is_concrete_int
Expand Down
44 changes: 44 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10505,6 +10505,50 @@ def f(x):
self.assertEqual(out_ref.stride(), out_test.stride())
self.assertEqual(x_ref, x_test)

@requires_gpu()
@skip_if_not_triton
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_inductor_multiple_specializations(self):
from triton.testing import do_bench

@torch.compile(
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
dynamic=False,
)
def inductor_matmul(a, b):
torch._check(a.shape[0] == b.shape[1])
return (m, torch.mm(a, b))

m = 16
k = 1280
dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
torch._dynamo.decorators.mark_dynamic(
dynamic_a,
0,
)
torch._dynamo.decorators.mark_dynamic(
dynamic_specialized_a,
0,
specialize_on=[lambda x0: x0 == 16],
)
torch._dynamo.decorators.mark_dynamic(
b,
1,
)
dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
torch._dynamo.reset()
dynamic_specialized = do_bench(
lambda: inductor_matmul(dynamic_specialized_a, b)
)
self.assertGreaterEqual(dynamic, dynamic_specialized)

@requires_gpu()
def test_stride_preservation_with_stride_modifying_fx_pass(self):
def f(x):
Expand Down
87 changes: 75 additions & 12 deletions torch/_dynamo/output_graph.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
import torch.distributed as dist
import torch.nn
import torch.utils._pytree as pytree
from torch import fx
from torch import fx, Tensor
from torch._C._dynamo import guards
from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
from torch._guards import (
CompileContext,
Expand All @@ -61,6 +62,7 @@
guard_scalar,
is_symbolic,
ShapeEnv,
Specialization,
)
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.multiprocessing.reductions import StorageWeakRef
Expand Down Expand Up @@ -158,6 +160,8 @@
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

RootGuardManager = guards.RootGuardManager


@dataclass(frozen=True)
class VariableTrackerCacheKey:
Expand Down Expand Up @@ -1675,7 +1679,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
self.tracing_context.fake_mode = backend_fake_mode

with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm)
compiled_fn = self.call_user_compiler(gm, self.example_inputs())

from torch.fx._lazy_graph_module import _LazyGraphModule

Expand Down Expand Up @@ -1705,8 +1709,62 @@ def compile_and_call_fx_graph(self, tx, rv, root):
)

counters["stats"]["unique_graphs"] += 1
# This is safe because we pre-process name to be unique
self.install_global_unsafe(name, compiled_fn)
if specializations := old_fake_mode.shape_env.specializations:
specialization_guards = []
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
sources = [a.source for a in self.graphargs]
for specialization in specializations:
source_index = sources.index(specialization.source)
check_fn_source = inspect.getsource(specialization.check_fn).strip()
check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined]
specialization.check_fn,
[check_fn_source],
)

log.debug(
"Compiling backend specialized graph with specialization=%s",
check_fn_source,
)

specialization_guards.append(
(
functools.partial(
lambda idx, args, check_fn=check_fn: check_fn(
args[idx]
),
source_index,
),
specialization,
)
)

@torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
def specialized_dispatch(*args, **kwargs):
for check_fn, specialization in specialization_guards:
if check_fn(args):
if specialization in specialization_cache:
return specialization_cache[specialization](
*args, **kwargs
)

with self.shape_env.patch_source_specialization(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the shape_env shared between Specializations? (it looks like yes).

If so, will this cause issues? Like if we have s0 == s1 and a specialization that says s0 == 2, does this cause s1 to always be 2 in subsequent Specializations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

patch_source_specialization is a context manager that temporarily adds axioms or replacements, keeping track of which ones it introduced and removing them when the context ends. This ensures that each specialization is completely independent and does not interfere with others.

Copy link
Contributor
@zou3519 zou3519 May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to get into a situation where:

  1. we add new axioms
  2. a pass later on in Inductor or AOTAutograd adds some more things to the shape env
  3. we remove the added axioms, but the things added in (2) stick around in the shape env and pollutes future specializations?

My naive take is that each Specialization would want its own ShapeEnv

Copy link
Contributor Author
@bobrenjc93 bobrenjc93 May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard for me to imagine a case where we will add additional guards/replacements on a subsequent more specialized graph than the first generic graph. The tricky thing with ShapeEnv is that it's not serializable because guards aren't serializable so it's unclear how we could maintain multiple parallel ShapeEnvs.

I just updated the PR to freeze the ShapeEnv after patching in the axioms/replacements so in the worst case we will fail the compile instead of silent incorrectness.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Offline discussion says: this may be a problem, we probably do want to clone the ShapeEnv, we will figure out how to resolve this in a followup.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I was imagining separate shape envs as well. maybe @ezyang can advise

specialization.source, specialization.check_fn
):
# Modify gm so AOTAutogradCache key changes per specialization
gm.meta["specialization"] = specialization
example_inputs: list[Tensor] = list(args)
specialization_cache[specialization] = (
self.call_user_compiler(gm, example_inputs)
)

return specialization_cache[specialization](*args, **kwargs)
return compiled_fn(*args, **kwargs)

# This is safe because we pre-process name to be unique
self.install_global_unsafe(name, specialized_dispatch)
else:
# This is safe because we pre-process name to be unique
self.install_global_unsafe(name, compiled_fn)

assert self.root_tx is not None
cg = PyCodegen(self.root_tx)
Expand All @@ -1721,7 +1779,9 @@ def placeholders(self) -> list[fx.Node]:
def graphargs(self) -> list[GraphArg]:
return [node.meta["grapharg"] for node in self.placeholders]

def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
def call_user_compiler(
self, gm: fx.GraphModule, example_inputs: list[Tensor]
) -> CompiledFn:
with dynamo_timed(
"OutputGraph.call_user_compiler",
phase_name="backend_compile",
Expand All @@ -1730,9 +1790,11 @@ def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
waitcounter_name_override="compile_aot_autograd",
dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
):
return self._call_user_compiler(gm)
return self._call_user_compiler(gm, example_inputs)

def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
def _call_user_compiler(
self, gm: fx.GraphModule, example_inputs: list[Tensor]
) -> CompiledFn:
assert self.compiler_fn is not None
tot = 0
placeholders = []
Expand All @@ -1743,10 +1805,11 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
placeholders.append(node)
increment_op_count(tot)
for pl in placeholders:
arg = pl.meta["grapharg"]
# TODO: Why isn't this stored in meta :think:
# NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
pl._dynamo_source = arg.source
if not hasattr(pl, "_dynamo_source"):
arg = pl.meta["grapharg"]
# TODO: Why isn't this stored in meta :think:
# NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
pl._dynamo_source = arg.source

# NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
Expand All @@ -1762,7 +1825,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
compiler_fn = self.compiler_fn
if config.verify_correctness:
compiler_fn = WrapperBackend(compiler_fn)
compiled_fn = compiler_fn(gm, self.example_inputs())
compiled_fn = compiler_fn(gm, example_inputs)
_step_logger()(logging.INFO, f"done compiler function {name}")
assert callable(compiled_fn), "compiler_fn did not return callable"
except (TensorifyScalarRestartAnalysis, ShortenTraceback):
Expand Down
4 changes: 4 additions & 0 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3134,6 +3134,7 @@ def update_dim2constraint(dim, constraint_range, name):
dynamic_strides = []
constraint_sizes = []
constraint_strides = []
specialize_on = []
for i in range(e.dim()):
# NB: mark dynamic has precedence over static
marked_strict_unbacked = i in getattr(
Expand All @@ -3144,6 +3145,8 @@ def update_dim2constraint(dim, constraint_range, name):
marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
marked_static = i in getattr(e, "_dynamo_static_indices", set())

specialize_on.append(getattr(e, "_specialize_on", {}).get(i, []))

# Reflect the user directive in the frame_state
# For dynamic, apply None always

Expand Down Expand Up @@ -3271,6 +3274,7 @@ def update_dim2constraint(dim, constraint_range, name):
dynamic_strides=dynamic_strides,
constraint_sizes=constraint_sizes,
constraint_strides=constraint_strides,
specialize_on=specialize_on,
view_base_context=view_base_context,
tensor_source=source,
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
Expand Down
8 changes: 5 additions & 3 deletions torch/_export/non_strict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ def fakify(
constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False) # type: ignore[call-overload]
else:
dynamic_sizes.append(DimDynamic.STATIC)
symbolic_context = StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
symbolic_context: StatelessSymbolicContext = ( # make mypy happy
StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
)
)
t_id = id(t)
assert mode.shape_env is not None
Expand Down
Loading
Loading
0