Support whitelist of dynamic sources · pytorch/pytorch@af1e70a · GitHub

Commit af1e70a

committed
Support whitelist of dynamic sources
ghstack-source-id: 2629127
Pull Request resolved: #147979
1 parent 5220d40 commit af1e70a

File tree

3 files changed: +72 -2 lines changed

  test/dynamo/test_misc.py
  torch/_dynamo/variables/builder.py
  torch/compiler/config.py

test/dynamo/test_misc.py

Lines changed: 48 additions & 0 deletions
@@ -7858,6 +7858,54 @@ def my_dyn_fn(x, y):
         with self.assertRaises(ConstraintViolationError):
             torch.compile(my_dyn_fn, backend="eager")(y, y)

+    @torch.compiler.config.patch(dynamic_sources="L['x']")
+    def test_dynamic_int_sources(self):
+        counter = CompileCounter()
+
+        @torch.compile(backend=counter)
+        def fn(x):
+            return torch.randn(5) * x
+
+        fn(1)
+        fn(2)
+        fn(3)
+
+        self.assertEqual(counter.frame_count, 1)
+
+    @torch.compiler.config.patch(dynamic_sources="L['x']")
+    def test_dynamic_tensor_sources(self):
+        counter = CompileCounter()
+
+        @torch.compile(backend=counter)
+        def fn(x):
+            return x * x
+
+        fn(torch.randn(2))
+        fn(torch.randn(3))
+        fn(torch.randn(4))
+
+        self.assertEqual(counter.frame_count, 1)
+
+    @torch.compiler.config.patch(dynamic_sources="L['x']")
+    def test_dynamic_sources_graph_break(self):
+        counter = CompileCounter()
+
+        def foo(x):
+            return x * x
+
+        @torch.compile(backend=counter)
+        def fn(x):
+            x = x * x
+            torch._dynamo.graph_break()
+            return foo(x)
+
+        fn(torch.randn(2))
+        fn(torch.randn(3))
+        fn(torch.randn(4))
+
+        # 2 since graph break produces 2 graphs. NB: there are no recompiles
+        self.assertEqual(counter.frame_count, 2)
+
     def test_cannot_trace_mark_dynamic(self):
         y = torch.randn([3, 3, 3])
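
In these tests the whitelist entry "L['x']" is Dynamo's source name for the local variable x of the compiled frame; in the graph-break test the same name evidently covers the resumed frame as well, since the counter sees two graphs but no recompiles across the three input sizes. A rough sketch of the equivalent non-test setup (the script name and shapes are illustrative, not part of this commit) supplies the same whitelist through the environment variable read by the new config:

# run_example.py -- illustrative only. Launch with the whitelist in the
# environment so the argument `x` (source name "L['x']") is dynamic from
# the first compilation:
#
#   TORCH_COMPILE_DYNAMIC_SOURCES="L['x']" python run_example.py

import torch

@torch.compile(backend="eager")
def fn(x):
    return x * x

fn(torch.randn(2))
fn(torch.randn(3))  # expected to reuse the first graph rather than recompile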

torch/_dynamo/variables/builder.py

Lines changed: 16 additions & 2 deletions
@@ -275,6 +275,7 @@
 static_inputs_log = torch._logging.getArtifactLogger(
     __name__, "cudagraph_static_inputs"
 )
+_DYNAMIC_SOURCES: Optional[set[str]] = None


 DimList = list
@@ -1955,7 +1956,10 @@ def wrap_symint(self, value):
             and frame_state_entry.scalar is auto_dynamic
         ):
             dynamic_dim = get_automatic_dynamic_shapes_mark_as()
-        elif not config.assume_static_by_default:
+        elif (
+            not config.assume_static_by_default
+            or self.source.name() in get_dynamic_sources()
+        ):
             dynamic_dim = DimDynamic.DYNAMIC
         else:  # assume_static_by_default
             # TODO: dynamic_dim = DimDynamic.STATIC should work but
@@ -2639,6 +2643,15 @@ def get_automatic_dynamic_shapes_mark_as():
     )


+def get_dynamic_sources() -> set[str]:
+    global _DYNAMIC_SOURCES
+    if _DYNAMIC_SOURCES is not None:
+        return _DYNAMIC_SOURCES
+
+    _DYNAMIC_SOURCES = set(torch.compiler.config.dynamic_sources.split(","))
+    return _DYNAMIC_SOURCES
+
+
 # Tracks the sources of all fake tensors we wrap in Dynamo.
 # Used by shape guard computation.
 @dataclasses.dataclass
@@ -2669,6 +2682,7 @@ def _automatic_dynamic(
         unimplemented("torch.compile does not support strided NestedTensor")

     name = source.name()
+    dynamic_sources = get_dynamic_sources()
     prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
     shape_env_to_source_to_symbol_cache = (
         prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
@@ -2807,7 +2821,7 @@ def update_dim2constraint(dim, constraint_range, name):

     # Reflect the user directive in the frame_state
     # For dynamic, apply None always
-    if marked_dynamic:
+    if marked_dynamic or source.name() in dynamic_sources:
         # TODO: This can be batched
         # TODO: Doing this here is kind of sus, maybe better to set this
         # up when we initially created the FrameStateSizeEntry to bong
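
get_dynamic_sources() parses torch.compiler.config.dynamic_sources once and caches the result in the module-level _DYNAMIC_SOURCES, so the whitelist needs to be in place before the first compilation that consults it; matching is an exact string comparison against Source.name(). A standalone sketch of that parse-once, exact-match pattern (the names below are illustrative stand-ins, not the PyTorch API):

from typing import Optional

_CACHED_SOURCES: Optional[set[str]] = None  # stand-in for builder._DYNAMIC_SOURCES

def parse_dynamic_sources(raw: str) -> set[str]:
    # Split a comma-delimited whitelist such as "L['x'],L['y']" into a set
    # and cache it, mirroring get_dynamic_sources() above.
    global _CACHED_SOURCES
    if _CACHED_SOURCES is None:
        _CACHED_SOURCES = set(raw.split(","))
    return _CACHED_SOURCES

def is_whitelisted(source_name: str, raw: str) -> bool:
    # Exact match, as in `self.source.name() in get_dynamic_sources()`.
    return source_name in parse_dynamic_sources(raw)

assert is_whitelisted("L['x']", "L['x'],L['y']")
assert not is_whitelisted("L['z']", "L['x'],L['y']")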

torch/compiler/config.py

Lines changed: 8 additions & 0 deletions
@@ -63,5 +63,13 @@
 A common use case for such a tag is to break caches.
 """

+dynamic_sources: str = Config(
+    env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default=""
+)
+"""
+Comma delimited list of sources that should be marked as dynamic. Primarily useful for large
+models with graph breaks where you need intermediate tensors and ints to be marked dynamic.
+"""
+

 install_config_module(sys.modules[__name__])
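
A possible programmatic route to the same setting, following the pattern of the new tests (the wrapper function below is an assumption for illustration). Because builder.py caches the parsed whitelist, this is best applied before the first torch.compile invocation in the process:

import torch

# Mirrors the tests above: mark the integer argument `x` (source "L['x']")
# as dynamic so calls with different values reuse one compiled frame.
@torch.compiler.config.patch(dynamic_sources="L['x']")
def main():
    @torch.compile(backend="eager")
    def fn(x):
        return torch.randn(5) * x

    fn(1)
    fn(2)  # expected: no recompile, x is traced as a SymInt

main()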

0 commit comments