diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 071a275b973e0..a7405872cf809 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -10464,6 +10464,50 @@ def f(x):
         self.assertEqual(out_ref.stride(), out_test.stride())
         self.assertEqual(x_ref, x_test)

+    @requires_gpu()
+    @skip_if_not_triton
+    @unittest.skipIf(
+        not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
+    )
+    def test_inductor_multiple_specializations(self):
+        from triton.testing import do_bench
+
+        @torch.compile(
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+            dynamic=False,
+        )
+        def inductor_matmul(a, b):
+            torch._check(a.shape[0] == b.shape[1])
+            return (m, torch.mm(a, b))
+
+        m = 16
+        k = 1280
+        dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
+        dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
+        b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
+        torch._dynamo.decorators.mark_dynamic(
+            dynamic_a,
+            0,
+        )
+        torch._dynamo.decorators.mark_dynamic(
+            dynamic_specialized_a,
+            0,
+            specialize_on=[lambda x0: x0 == 16],
+        )
+        torch._dynamo.decorators.mark_dynamic(
+            b,
+            1,
+        )
+        dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
+        torch._dynamo.reset()
+        dynamic_specialized = do_bench(
+            lambda: inductor_matmul(dynamic_specialized_a, b)
+        )
+        self.assertGreaterEqual(dynamic, dynamic_specialized)
+
     @requires_gpu()
     def test_stride_preservation_with_stride_modifying_fx_pass(self):
         def f(x):
diff --git a/torch/__init__.py b/torch/__init__.py
index d9d58b630060e..2304687ed846a 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -2359,10 +2359,10 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
                 )
             self.config[attr_name] = val

-    def __call__(self, model_, inputs_):
+    def __call__(self, model_, inputs_, **kwargs):
         from torch._inductor.compile_fx import compile_fx

-        return compile_fx(model_, inputs_, config_patches=self.config)
+        return compile_fx(model_, inputs_, config_patches=self.config, **kwargs)

     def get_compiler_config(self):
         from torch._inductor.compile_fx import get_patched_config_dict
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index af872a1f207d2..01fc9d7842139 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -43,7 +43,8 @@
 import torch.distributed as dist
 import torch.nn
 import torch.utils._pytree as pytree
-from torch import fx
+from torch import fx, Tensor
+from torch._C._dynamo import guards
 from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
 from torch._guards import (
     CompileContext,
@@ -61,6 +62,7 @@
     guard_scalar,
     is_symbolic,
     ShapeEnv,
+    Specialization,
 )
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -157,6 +159,8 @@
 graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

+RootGuardManager = guards.RootGuardManager
+

 @dataclass(frozen=True)
 class VariableTrackerCacheKey:
@@ -1497,7 +1501,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             self.tracing_context.fake_mode = backend_fake_mode

         with self.restore_global_state():
-            compiled_fn = self.call_user_compiler(gm)
+            compiled_fn = self.call_user_compiler(gm, self.example_inputs())

         from torch.fx._lazy_graph_module import _LazyGraphModule

@@ -1528,7 +1532,60 @@ def compile_and_call_fx_graph(self, tx, rv, root):
         counters["stats"]["unique_graphs"] += 1
         # This is safe because we pre-process name to be unique
-        self.install_global_unsafe(name, compiled_fn)
+        if specializations := old_fake_mode.shape_env.specializations:
+            specialization_guards = []
+            specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
+            sources = [a.source for a in self.graphargs]
+            for specialization in specializations:
+                source_index = sources.index(specialization.source)
+                check_fn_source = inspect.getsource(specialization.check_fn).strip()
+                check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
+                    specialization.check_fn,
+                    [check_fn_source],
+                )
+
+                log.debug(
+                    "Compiling backend specialized graph with specialization=%s",
+                    check_fn_source,
+                )
+
+                specialization_guards.append(
+                    (
+                        functools.partial(
+                            lambda idx, args, check_fn=check_fn: check_fn(
+                                args[idx]
+                            ),
+                            source_index,
+                        ),
+                        specialization,
+                    )
+                )
+
+            @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
+            def specialized_dispatch(*args, **kwargs):
+                for check_fn, specialization in specialization_guards:
+                    if check_fn(args):
+                        if specialization in specialization_cache:
+                            return specialization_cache[specialization](
+                                *args, **kwargs
+                            )
+
+                        with self.shape_env.patch_source_specialization(
+                            specialization.source, specialization.check_fn
+                        ):
+                            # Modify gm so AOTAutogradCache key changes per specialization
+                            gm.meta["specialization"] = specialization
+                            example_inputs: list[Tensor] = list(args)
+                            specialization_cache[specialization] = (
+                                self.call_user_compiler(gm, example_inputs)
+                            )
+
+                        return specialization_cache[specialization](*args, **kwargs)
+                return compiled_fn(*args, **kwargs)
+
+            self.install_global_unsafe(name, specialized_dispatch)
+        else:
+            self.install_global_unsafe(name, compiled_fn)

         cg = PyCodegen(tx)
         cg.make_call_generated_code(name)
@@ -1542,16 +1599,20 @@ def placeholders(self) -> list[fx.Node]:
     def graphargs(self) -> list[GraphArg]:
         return [node.meta["grapharg"] for node in self.placeholders]

-    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: list[Tensor]
+    ) -> CompiledFn:
         with dynamo_timed(
             "OutputGraph.call_user_compiler",
             phase_name="backend_compile",
             log_pt2_compile_event=True,
             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
         ):
-            return self._call_user_compiler(gm)
+            return self._call_user_compiler(gm, example_inputs)

-    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def _call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: list[Tensor]
+    ) -> CompiledFn:
         assert self.compiler_fn is not None
         tot = 0
         placeholders = []
@@ -1562,10 +1623,11 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
                 placeholders.append(node)
         increment_op_count(tot)
         for pl in placeholders:
-            arg = pl.meta["grapharg"]
-            # TODO: Why isn't this stored in meta :think:
-            # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
-            pl._dynamo_source = arg.source
+            if not hasattr(pl, "_dynamo_source"):
+                arg = pl.meta["grapharg"]
+                # TODO: Why isn't this stored in meta :think:
+                # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
+                pl._dynamo_source = arg.source

         # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
         gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
@@ -1581,7 +1643,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             compiler_fn = self.compiler_fn
             if config.verify_correctness:
                 compiler_fn = WrapperBackend(compiler_fn)
-            compiled_fn = compiler_fn(gm, self.example_inputs())
+            compiled_fn = compiler_fn(gm, example_inputs)
             _step_logger()(logging.INFO, f"done compiler function {name}")
             assert callable(compiled_fn), "compiler_fn did not return callable"
         except (TensorifyScalarRestartAnalysis, ShortenTraceback):
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 092a8d1e6428d..cc2dd1648687c 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -3053,6 +3053,7 @@ def update_dim2constraint(dim, constraint_range, name):
         dynamic_strides = []
         constraint_sizes = []
         constraint_strides = []
+        specialize_on = []
         for i in range(e.dim()):
             # NB: mark dynamic has precedence over static
             marked_strict_unbacked = i in getattr(
@@ -3063,6 +3064,8 @@ def update_dim2constraint(dim, constraint_range, name):
             marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
             marked_static = i in getattr(e, "_dynamo_static_indices", set())

+            specialize_on.append(getattr(e, "_specialize_on", {}).get(i, []))
+
             # Reflect the user directive in the frame_state
             # For dynamic, apply None always

@@ -3182,6 +3185,7 @@ def update_dim2constraint(dim, constraint_range, name):
             dynamic_strides=dynamic_strides,
             constraint_sizes=constraint_sizes,
             constraint_strides=constraint_strides,
+            specialize_on=specialize_on,
             view_base_context=view_base_context,
             tensor_source=source,
             shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py
index f67c9b0c7eae9..15d13c414fe66 100644
--- a/torch/_export/non_strict_utils.py
+++ b/torch/_export/non_strict_utils.py
@@ -143,9 +143,11 @@ def fakify(
             constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False)  # type: ignore[call-overload]
         else:
             dynamic_sizes.append(DimDynamic.STATIC)
-    symbolic_context = StatelessSymbolicContext(
-        dynamic_sizes=dynamic_sizes,
-        constraint_sizes=constraint_sizes,  # type: ignore[arg-type]
+    symbolic_context: StatelessSymbolicContext = (  # make mypy happy
+        StatelessSymbolicContext(
+            dynamic_sizes=dynamic_sizes,
+            constraint_sizes=constraint_sizes,  # type: ignore[arg-type]
+        )
     )
     t_id = id(t)
     assert mode.shape_env is not None
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index c384374d3da03..b168604e4bb1a 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -40,6 +40,7 @@
     Any,
     Callable,
     cast,
+    Generic,
     NamedTuple,
     NoReturn,
     Optional,
@@ -47,7 +48,7 @@
     TypeVar,
     Union,
 )
-from typing_extensions import deprecated, TypeAlias, TypeGuard
+from typing_extensions import deprecated, ParamSpec, TypeAlias, TypeGuard

 import torch
 import torch.fx
@@ -102,6 +103,7 @@
     import types

     from torch import Tensor
+    from torch._dynamo.source import TensorPropertySource
     from torch._subclasses.fake_tensor import FakeTensor
     from torch.types import BoolLikeType, FloatLikeType, IntLikeType

@@ -167,6 +169,7 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
     "is_accessor_node",
     "ValueRangesSLoc",
     "SymIntEqByExpr",
+    "Specialization",
 ]

 # FX node metadata keys for symbolic shape FX graph.
@@ -931,6 +934,20 @@ def find_symbol_binding_fx_nodes(
     return r


+@dataclass(frozen=True)
+class Specialization:
+    """
+    This class is used in multi-graph compilation contexts where we generate
+    multiple specialized graphs and dispatch to the appropriate one at runtime.
+    This allows us to optimize the trade-off between performance and generality
+    by creating specialized versions for common patterns (e.g., x.shape[0] % 16 == 0)
+    while maintaining a general fallback.
+    """
+
+    source: TensorPropertySource
+    check_fn: Callable
+
+
 # Analogous to ConvertIntSource
 @dataclass(frozen=True)
 class ConvertIntKey:
@@ -1894,8 +1911,12 @@ class SymIntSymbolicContext(SymbolicContext):
     constraint: DimConstraint


+_P1 = ParamSpec("_P1")
+_T1 = TypeVar("_T1")
+
+
 @dataclass(frozen=True)
-class StatelessSymbolicContext(SymbolicContext):
+class StatelessSymbolicContext(Generic[_P1, _T1], SymbolicContext):
     """
     Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
     a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``.
@@ -1906,6 +1927,7 @@ class StatelessSymbolicContext(SymbolicContext):
     dynamic_strides: DimList[DimDynamic] = None  # type: ignore[assignment]
     constraint_sizes: DimList[DimConstraint] = None  # type: ignore[assignment]
     constraint_strides: DimList[DimConstraint] = None  # type: ignore[assignment]
+    specialize_on: Optional[list[list[Callable[_P1, _T1]]]] = None
     # If the tensor is a view, this should be populated for the base. It contains
     # information on how to allocate symbols when recursively fakeifying the base
     # during view fake-ification.
@@ -1913,6 +1935,12 @@ class StatelessSymbolicContext(SymbolicContext):
     # TODO: add storage offset and stride symbolic_context

     def __post_init__(self) -> None:
+        if self.specialize_on is None:
+            object.__setattr__(
+                self,
+                "specialize_on",
+                [[]] * len(self.dynamic_sizes),
+            )
         if self.dynamic_strides is None:
             object.__setattr__(
                 self,
@@ -3542,6 +3570,8 @@ def _init(

         self.trace_asserts = trace_asserts

+        self.specializations: OrderedSet[Specialization] = OrderedSet()
+
         from torch.fx.experimental.validator import translation_validation_enabled

         self._translation_validation_enabled = translation_validation_enabled()
@@ -3593,6 +3623,47 @@ def prefer_deferred_runtime_asserts_over_guards(self) -> bool:
     def allow_complex_guards_as_runtime_asserts(self) -> bool:
         return self.settings.allow_complex_guards_as_runtime_asserts

+    @contextmanager
+    def patch_source_specialization(
+        self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr]
+    ) -> Iterator[None]:
+        """
+        Temporarily add symbol-level axioms to the ShapeEnv. This is useful when you want to "fork"
+        and have parallel universes of ShapeEnvs. For example, we use this when doing multi-graph
+        compile so we can support various graphs with varying levels of specializations.
+
+        This context manager allows for temporarily adding constraints to the shape environment
+        based on a specialization function applied to a symbol associated with a source.
+
+        Args:
+            source: The source of the symbol to specialize
+            check_fn: A function that takes a sympy Symbol and returns a sympy expression
+                representing a constraint/specialization to be applied
+        """
+        name = source.name()
+        sym = self.source_to_var[name]
+        expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr
+        new_axioms = dict(self.get_implications(self.simplify(expr)))
+        added_replacements = {}
+        for axiom in new_axioms:
+            if (
+                isinstance(axiom, sympy.Eq)
+                and isinstance(axiom.lhs, sympy.Symbol)
+                and isinstance(axiom.rhs, sympy.Integer)
+                and axiom.lhs not in self.replacements
+            ):
+                self.replacements[axiom.lhs] = axiom.rhs
+                added_replacements[axiom.lhs] = axiom.rhs
+
+        self.axioms.update(new_axioms)
+        try:
+            yield
+        finally:
+            for k in new_axioms:
+                self.axioms.pop(k, None)
+            for k in added_replacements:
+                self.replacements.pop(k, None)
+
     def check_equal(self, other: ShapeEnv) -> None:
         """Compare another ShapeEnv for equivalence"""
         # ShapeEnv fields that are not relevant for the outcome of
@@ -4025,6 +4096,17 @@ def _produce_dyn_sizes_from_int_tuple(
                 do_not_specialize_zero_one=config.backed_size_oblivious,
                 symbolic_context=symbolic_context,
             )
+            if (
+                isinstance(symbolic_context, StatelessSymbolicContext)
+                and symbolic_context.specialize_on
+            ):
+                for specialization in symbolic_context.specialize_on[i]:
+                    self.specializations.add(
+                        Specialization(
+                            TensorPropertySource(source, TensorProperty.SIZE, i),
+                            specialization,
+                        )
+                    )
             if (
                 config.backed_size_oblivious
                 and isinstance(sym, sympy.Symbol)  # could be static
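For reference, a minimal usage sketch of the specialize_on hook exercised by the test above. The shapes, the `s == 16` predicate, and the plain torch.compile call (without the test's Triton/max-autotune options or GPU placement) are illustrative assumptions, not requirements of the API:

    import torch

    @torch.compile(dynamic=False)
    def matmul(a, b):
        return torch.mm(a, b)

    a = torch.randn(16, 1280)
    b = torch.randn(1280, 16)

    # Dim 0 of `a` stays dynamic, but Dynamo additionally records a specialization
    # predicate; the runtime dispatcher compiles and uses a specialized graph when
    # a.shape[0] == 16 holds and falls back to the general dynamic graph otherwise.
    torch._dynamo.decorators.mark_dynamic(a, 0, specialize_on=[lambda s: s == 16])
    torch._dynamo.decorators.mark_dynamic(b, 1)

    out = matmul(a, b)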