Update on "thread through specialization to compile_fx" · pytorch/pytorch@7982534 · GitHub
Commit 7982534

committed
Update on "thread through specialization to compile_fx"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
1 parent 67d9211 commit 7982534

File tree

2 files changed: +14 additions, -1 deletion

torch/_dynamo/output_graph.py

Lines changed: 1 addition & 1 deletion

@@ -1498,7 +1498,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):

         compiled_fns = []
         with self.restore_global_state():
-            for specialization in backend_specializations:
+            for specialization in old_fake_mode.shape_env.backend_specializations:
                 compiled_fns.append(self.call_user_compiler(modified_gm, specialization))

         from torch.fx._lazy_graph_module import _LazyGraphModule
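
The change makes compile_and_call_fx_graph read its specializations from the shape environment of the outgoing fake mode instead of an unresolved local name, then compile one artifact per specialization. Below is a minimal, self-contained sketch of that compile-per-specialization pattern; the graph, backend, and Specialization names are illustrative stand-ins, not the actual Dynamo internals.

from dataclasses import dataclass
from typing import Callable, List

@dataclass
class Specialization:
    hint: int                     # example value the specialized variant targets
    check: Callable[[int], bool]  # predicate deciding when the variant applies

def compile_per_specialization(graph, backend, specializations) -> List:
    # Mirror the loop in compile_and_call_fx_graph: one compiled function
    # per recorded specialization.
    compiled_fns = []
    for spec in specializations:
        compiled_fns.append(backend(graph, specialization=spec))
    return compiled_fns

# Usage sketch with a toy "backend" that just tags the graph with its specialization.
if __name__ == "__main__":
    toy_backend = lambda g, specialization: (g, specialization)
    fns = compile_per_specialization(
        graph=lambda x: x * 2,
        backend=toy_backend,
        specializations=[Specialization(64, lambda n: n % 8 == 0)],
    )
    print(len(fns))  # one compiled variant per specialization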

torch/fx/experimental/symbolic_shapes.py

Lines changed: 13 additions & 0 deletions

@@ -931,6 +931,12 @@ def find_symbol_binding_fx_nodes(
     return r


+@dataclass
+class BackendSpecialization:
+    symbol: sympy.Symbol
+    hint: int
+    specialization: Callable
+
 # Analogous to ConvertIntSource
 @dataclass(frozen=True)
 class ConvertIntKey:

@@ -3561,6 +3567,8 @@ def _init(

         self.trace_asserts = trace_asserts

+        self.backend_specializations = []
+
         from torch.fx.experimental.validator import translation_validation_enabled

         self._translation_validation_enabled = translation_validation_enabled()

@@ -4044,6 +4052,11 @@ def _produce_dyn_sizes_from_int_tuple(
                 do_not_specialize_zero_one=config.backed_size_oblivious,
                 symbolic_context=symbolic_context,
             )
+            for specialization in symbolic_context.backend_specializations:
+                self.backend_specializations.append(BackendSpecialization(
+                    sym,
+                    *specialization,
+                ))
             if (
                 config.backed_size_oblivious
                 and isinstance(sym, sympy.Symbol)  # could be static
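
Each appended entry pairs the freshly created symbol with a specialization taken from symbolic_context.backend_specializations; the *specialization unpacking suggests each entry is a (hint, callable) pair that fills the two remaining BackendSpecialization fields. A minimal standalone sketch of that unpacking, under that pair-layout assumption (the s0 symbol and the predicate below are made up for illustration):

from dataclasses import dataclass
from typing import Callable

import sympy

@dataclass
class BackendSpecialization:
    symbol: sympy.Symbol
    hint: int
    specialization: Callable

# Hypothetical (hint, specialization) pair as the symbolic context might
# supply it; unpacked into the dataclass alongside the symbol it refers to.
s0 = sympy.Symbol("s0", integer=True, positive=True)
pair = (64, lambda size: size % 8 == 0)
spec = BackendSpecialization(s0, *pair)
assert spec.hint == 64 and spec.specialization(spec.hint)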

0 commit comments