wip · pytorch/pytorch@0a46e6d · GitHub
Commit 0a46e6d

wip
[ghstack-poisoned]
1 parent 92d66de commit 0a46e6d

File tree

3 files changed: +19 −18 lines changed

torch/_dynamo/codegen.py
torch/_dynamo/output_graph.py
torch/fx/experimental/symbolic_shapes.py

torch/_dynamo/codegen.py

Lines changed: 10 additions & 9 deletions

@@ -18,7 +18,7 @@
 import sys
 import types
 from collections import Counter
-from typing import Optional, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union, Any

 import torch.nn
 from torch.utils._ordered_set import OrderedSet

@@ -597,12 +597,14 @@ def mark_source_temp(self, source: Source) -> None:
         if source not in self.tempvars:
             self.tempvars[source] = None

-    def make_specialized_call_generated_code(
+    def make_call_specialized_code(
         self,
         fn_name: str,
-        specializations: list[tuple[str, BackendSpecialization]],
+        specializations: list[tuple[str, BackendSpecialization, Any]],
     ) -> None:
         """Try specializations in order; fall back to fn_name if none match"""
+        from .variables.builder import GraphArg
+
         graphargs = self.tx.output.graphargs
         seen_sources: OrderedSet[Source] = OrderedSet()

@@ -620,14 +622,13 @@ def collect_temp_source(source):
             if arg.source is not None:
                 collect_temp_source(arg.source)

-        for fn_spec_name, spec in specializations:
-            # Reconstruct the source expression to evaluate the specialization condition
-            self.call_reconstruct(GraphArg(source=spec.source, example_value=None))
-            self.extend_output(self.create_store_var("spec_value"))
+        for fn_spec_name, spec, _ in specializations:
+            self.call_reconstruct(spec.source)
+            self.extend_output(self.create_store("spec_value"))

             # Load the specialization function and call it with spec_value
-            self.extend_output(self.create_load_const(spec.specialization))
-            self.extend_output(self.create_load_var("spec_value"))
+            self.extend_output(self.create_load_const(spec.check_fn))
+            self.extend_output(self.create_load("spec_value"))
             self.extend_output(create_call_function(1, False))

             skip_label = self.new_block()
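For readers following the codegen change: for each (fn_spec_name, spec, _) entry, the rewritten loop reconstructs spec.source into a "spec_value" temp, loads spec.check_fn as a constant, calls it on that value, and skips past the specialized call when the check fails. Below is a minimal, self-contained Python sketch of the dispatch the emitted bytecode amounts to; make_dispatcher and the lambdas are illustrative stand-ins, not the real generated code or Dynamo APIs.

def make_dispatcher(specializations, general_fn):
    # specializations: (check_fn, specialized_fn) pairs, tried in order,
    # mirroring "try specializations in order; fall back if none match".
    def dispatch(spec_value, *args):
        for check_fn, specialized_fn in specializations:
            if check_fn(spec_value):        # first passing check wins
                return specialized_fn(*args)
        return general_fn(*args)            # none matched: general graph
    return dispatch

# Usage: one specialization keyed on a size hint of 1.
dispatch = make_dispatcher(
    [(lambda s: s == 1, lambda *a: "size-1 kernel")],
    lambda *a: "general kernel",
)
print(dispatch(1))   # -> size-1 kernel
print(dispatch(7))   # -> general kernel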

torch/_dynamo/output_graph.py

Lines changed: 6 additions & 6 deletions

@@ -1496,11 +1496,11 @@ def compile_and_call_fx_graph(self, tx, rv, root):
         # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
         self.tracing_context.fake_mode = backend_fake_mode

-        specialized_compiled_fns = []
+        specialized_compiles = []
         with self.restore_global_state():
             compiled_fn = self.call_user_compiler(gm)
             for specialization in old_fake_mode.shape_env.backend_specializations:
-                specialized_compiled_fns.append((
+                specialized_compiles.append((
                     unique_id("__specialized_compiled_fn"),
                     specialization,
                     self.call_user_compiler(gm, specialization=specialization)

@@ -1539,10 +1539,10 @@ def compile_and_call_fx_graph(self, tx, rv, root):

         cg = PyCodegen(tx)

-        if specialized_compiled_fns:
-            for name, specialized_compiled_fn in specialized_compiled_fns:
-                self.install_global_unsafe(name, specialized_compiled_fn)
-            cg.make_call_specialized_code(name, specialized_compiled_fns)
+        if specialized_compiles:
+            for fn_name, _, specialized_compiled_fn in specialized_compiles:
+                self.install_global_unsafe(fn_name, specialized_compiled_fn)
+            cg.make_call_specialized_code(name, specialized_compiles)
         else:
             cg.make_call_generated_code(name)
         return cg.get_instructions()
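The tuples built here now carry three fields: a fresh global name from unique_id, the BackendSpecialization itself, and the artifact compiled for it. The name/artifact pair is installed as a global so the generated bytecode can load it, while the specialization travels on to make_call_specialized_code for the runtime check. A rough, self-contained sketch of that bookkeeping, where GLOBALS, install_global, register, and compile_fn are hypothetical stand-ins rather than real Dynamo internals:

from typing import Any, Callable

GLOBALS: dict[str, Callable] = {}

def install_global(name: str, fn: Callable) -> None:
    # Stand-in for install_global_unsafe: generated code loads fn by name.
    GLOBALS[name] = fn

def register(specs: list[Any], compile_fn: Callable) -> list[tuple[str, Any, Callable]]:
    entries = []
    for i, spec in enumerate(specs):
        name = f"__specialized_compiled_fn_{i}"  # stand-in for unique_id(...)
        fn = compile_fn(spec)
        install_global(name, fn)
        entries.append((name, spec, fn))         # the 3-tuple unpacked above
    return entries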

torch/fx/experimental/symbolic_shapes.py

Lines changed: 3 additions & 3 deletions

@@ -937,7 +937,7 @@ def find_symbol_binding_fx_nodes(
     return r


-@dataclass
+@dataclass(frozen=True)
 class BackendSpecialization:
     source: TensorPropertySource
     hint: int

@@ -3579,7 +3579,7 @@ def _init(

         self.trace_asserts = trace_asserts

-        self.backend_specializations = []
+        self.backend_specializations = set()

         from torch.fx.experimental.validator import translation_validation_enabled

@@ -4065,7 +4065,7 @@ def _produce_dyn_sizes_from_int_tuple(
                 symbolic_context=symbolic_context,
             )
             for specialization in symbolic_context.backend_specializations[i]:
-                self.backend_specializations.append(BackendSpecialization(
+                self.backend_specializations.add(BackendSpecialization(
                     TensorPropertySource(source, TensorProperty.SIZE, i),
                     *specialization,
                 ))
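The switch from list.append to set.add is what motivates @dataclass(frozen=True): a plain mutable dataclass sets __hash__ to None, so adding an instance to a set raises TypeError, while a frozen dataclass gets a synthesized __hash__ consistent with its field-wise __eq__, letting duplicate specializations collapse to a single entry. A small self-contained demonstration, with Spec as a simplified stand-in for BackendSpecialization:

from dataclasses import dataclass

@dataclass(frozen=True)
class Spec:
    dim: int
    hint: int

seen = set()
seen.add(Spec(0, 1))
seen.add(Spec(0, 1))   # equal fields -> equal hash -> deduplicated
assert len(seen) == 1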
