diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py
index dc3e96614f1493..ee9ea2234817ed 100644
--- a/torch/_dynamo/codegen.py
+++ b/torch/_dynamo/codegen.py
@@ -18,7 +18,7 @@
 import sys
 import types
 from collections import Counter
-from typing import Optional, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union, Any
 
 import torch.nn
 from torch.utils._ordered_set import OrderedSet
@@ -597,12 +597,14 @@ def mark_source_temp(self, source: Source) -> None:
         if source not in self.tempvars:
             self.tempvars[source] = None
 
-    def make_specialized_call_generated_code(
+    def make_call_specialized_code(
         self,
         fn_name: str,
-        specializations: list[tuple[str, BackendSpecialization]],
+        specializations: list[tuple[str, BackendSpecialization, Any]],
     ) -> None:
         """Try specializations in order; fall back to fn_name if none match"""
+        from .variables.builder import (GraphArg)
+
         graphargs = self.tx.output.graphargs
         seen_sources: OrderedSet[Source] = OrderedSet()
@@ -620,14 +622,13 @@ def collect_temp_source(source):
             if arg.source is not None:
                 collect_temp_source(arg.source)
 
-        for fn_spec_name, spec in specializations:
-            # Reconstruct the source expression to evaluate the specialization condition
-            self.call_reconstruct(GraphArg(source=spec.source, example_value=None))
-            self.extend_output(self.create_store_var("spec_value"))
+        for fn_spec_name, spec, _ in specializations:
+            self.call_reconstruct(spec.source)
+            self.extend_output(self.create_store("spec_value"))
 
             # Load the specialization function and call it with spec_value
-            self.extend_output(self.create_load_const(spec.specialization))
-            self.extend_output(self.create_load_var("spec_value"))
+            self.extend_output(self.create_load_const(spec.check_fn))
+            self.extend_output(self.create_load("spec_value"))
             self.extend_output(create_call_function(1, False))
 
             skip_label = self.new_block()
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 1f9fbe362a89f5..5394ea0ff19680 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -1496,11 +1496,11 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
             self.tracing_context.fake_mode = backend_fake_mode
 
-        specialized_compiled_fns = []
+        specialized_compiles = []
         with self.restore_global_state():
             compiled_fn = self.call_user_compiler(gm)
             for specialization in old_fake_mode.shape_env.backend_specializations:
-                specialized_compiled_fns.append((
+                specialized_compiles.append((
                     unique_id("__specialized_compiled_fn"),
                     specialization,
                     self.call_user_compiler(gm, specialization=specialization)
@@ -1539,10 +1539,10 @@ def compile_and_call_fx_graph(self, tx, rv, root):
 
         cg = PyCodegen(tx)
 
-        if specialized_compiled_fns:
-            for name, specialized_compiled_fn in specialized_compiled_fns:
-                self.install_global_unsafe(name, specialized_compiled_fn)
-            cg.make_call_specialized_code(name, specialized_compiled_fns)
+        if specialized_compiles:
+            for fn_name, _, specialized_compiled_fn in specialized_compiles:
+                self.install_global_unsafe(fn_name, specialized_compiled_fn)
+            cg.make_call_specialized_code(name, specialized_compiles)
         else:
             cg.make_call_generated_code(name)
         return cg.get_instructions()
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index f39c762c67c916..d0f18822a635a9 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -937,7 +937,7 @@ def find_symbol_binding_fx_nodes(
     return r
 
 
-@dataclass
+@dataclass(frozen=True)
 class BackendSpecialization:
     source: TensorPropertySource
     hint: int
@@ -3579,7 +3579,7 @@ def _init(
 
         self.trace_asserts = trace_asserts
 
-        self.backend_specializations = []
+        self.backend_specializations = set()
 
         from torch.fx.experimental.validator import translation_validation_enabled
@@ -4065,7 +4065,7 @@ def _produce_dyn_sizes_from_int_tuple(
                 symbolic_context=symbolic_context,
             )
             for specialization in symbolic_context.backend_specializations[i]:
-                self.backend_specializations.append(BackendSpecialization(
+                self.backend_specializations.add(BackendSpecialization(
                     TensorPropertySource(source, TensorProperty.SIZE, i),
                     *specialization,
                 ))
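Note (not part of the patch): a minimal Python-level sketch of the dispatch that make_call_specialized_code is intended to emit as bytecode, assuming each entry of specialized_compiles is a (fn_name, BackendSpecialization, compiled_fn) triple as built in output_graph.py. The helper name dispatch_specialized and the argument spelling are illustrative only.

# Illustrative sketch only (assumed semantics, not code from this patch):
# try each specialization in order, fall back to the generic compiled fn.
def dispatch_specialized(specialized_compiles, generic_fn, spec_value, *args):
    # specialized_compiles: list of (fn_name, BackendSpecialization, compiled_fn)
    for _fn_name, spec, compiled_fn in specialized_compiles:
        # spec.check_fn is loaded as a constant and called on the value
        # reconstructed from spec.source ("spec_value") by the emitted bytecode.
        if spec.check_fn(spec_value):
            return compiled_fn(*args)
    # No specialization matched: call the fallback, mirroring make_call_generated_code.
    return generic_fn(*args)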