[multigraph] fix composabilty with aotautograd cache (#153526) · pytorch/pytorch@e7bf72c

Commit e7bf72c

bobrenjc93 authored and pytorchmergebot committed
[multigraph] fix composabilty with aotautograd cache (#153526)
AOTAutogradCache uses FXGraphCache, which uses the tracing context to get the ShapeEnv. Although the TracingContext's global_context is cleared by the time we get around to reusing it, we don't actually need it: we only need the ShapeEnv in the TracingContext, which is not cleared at the end of dynamo and does persist. This PR adds the tracing context manager around the specialized compile to ensure our caching infrastructure can get access to the ShapeEnv. A test was also added to prove correctness.

Pull Request resolved: #153526
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
ghstack dependencies: #153433, #153449
1 parent 7183f52 commit e7bf72c
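As a rough illustration of the mechanism (a minimal sketch, not the PR's or FXGraphCache's actual code; the helper name and standalone setup below are made up), cache consumers that call TracingContext.try_get() can only reach the ShapeEnv while a tracing(...) block is active:

from torch._guards import TracingContext, tracing
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv


def shape_env_from_ambient_context():
    # Hypothetical helper: look up the active TracingContext the way cache
    # code does; with no tracing(...) block active there is nothing to read.
    ctx = TracingContext.try_get()
    if ctx is None or ctx.fake_mode is None:
        return None
    return ctx.fake_mode.shape_env


fake_mode = FakeTensorMode(shape_env=ShapeEnv())
ctx = TracingContext(fake_mode)

assert shape_env_from_ambient_context() is None  # no ambient tracing context yet
with tracing(ctx):
    # Inside the context manager the persisted ShapeEnv is reachable again,
    # which is what wrapping the specialized compile guarantees.
    assert shape_env_from_ambient_context() is fake_mode.shape_env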

File tree: 2 files changed (+48 -3 lines)

test/dynamo/test_aot_autograd_cache.py

Lines changed: 43 additions & 0 deletions

@@ -272,6 +272,49 @@ def fn(x, y):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_multi_graph_specialization(self):
+        """
+        Verify multi graph specializations all cache hit
+        """
+
+        def fn(x):
+            return x * 5
+
+        a = torch.randn(5)
+        a8 = torch.randn(8)
+        a16 = torch.randn(16)
+        torch._dynamo.mark_dynamic(
+            a,
+            0,
+            specialize_on=[
+                lambda x: x == 8,
+                lambda x: x == 16,
+            ],
+        )
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        compiled_fn(a)
+        compiled_fn(a8)
+        compiled_fn(a16)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 3)
+
+        self._clear_dynamo_and_codecache()
+
+        # A second call should hit on all 3 graphs
+        compiled_fn(a)
+        compiled_fn(a8)
+        compiled_fn(a16)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 3)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
     @functorch_config.patch({"enable_autograd_cache": True})

torch/_dynamo/output_graph.py

Lines changed: 5 additions & 3 deletions

@@ -51,6 +51,7 @@
     CompileId,
     GlobalContextCheckpointState,
     Source,
+    tracing,
     TracingContext,
 )
 from torch._subclasses.fake_tensor import FakeTensor

@@ -1753,9 +1754,10 @@ def specialized_dispatch(*args, **kwargs):
                 # Modify gm so AOTAutogradCache key changes per specialization
                 gm.meta["specialization"] = specialization
                 example_inputs: list[Tensor] = list(args)
-                specialization_cache[specialization] = (
-                    self.call_user_compiler(gm, example_inputs)
-                )
+                with tracing(self.tracing_context):
+                    specialization_cache[specialization] = (
+                        self.call_user_compiler(gm, example_inputs)
+                    )

             return specialization_cache[specialization](*args, **kwargs)
         return compiled_fn(*args, **kwargs)
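For context, here is a condensed, hypothetical sketch of the dispatch pattern this hunk modifies (names and surrounding structure are simplified, not the real OutputGraph implementation). Each specialization is compiled lazily on first use, and the change is that the deferred compile now runs inside tracing(self.tracing_context) so its cache lookups can see the persisted ShapeEnv:

from torch._guards import tracing


def make_specialized_dispatch(
    tracing_context, compile_fn, generic_compiled_fn, specialization_guards
):
    # Hypothetical factory condensing the OutputGraph pattern: pick a matching
    # specialization at call time, compile it lazily, and cache the result.
    specialization_cache = {}

    def specialized_dispatch(*args, **kwargs):
        for check_fn, specialization in specialization_guards:
            if check_fn(args):
                if specialization not in specialization_cache:
                    # The fix: re-enter dynamo's TracingContext so the deferred
                    # compile (and the FXGraphCache / AOTAutogradCache lookups
                    # inside it) can reach the persisted ShapeEnv.
                    with tracing(tracing_context):
                        specialization_cache[specialization] = compile_fn(
                            specialization, list(args)
                        )
                return specialization_cache[specialization](*args, **kwargs)
        # No specialization matched: fall back to the generic compiled graph.
        return generic_compiled_fn(*args, **kwargs)

    return specialized_dispatch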
