[aot cache][ca] remove restriction on caching ca's aot inference grap… · pytorch/pytorch@666508e · GitHub
[go: up one dir, main page]

Skip to content

Commit 666508e

Browse files
xmfan authored and pytorchmergebot committed
[aot cache][ca] remove restriction on caching ca's aot inference graph (#148491)
but still can't cache CA's aot inference graph yet: the CA functional ops aren't serializable Pull Request resolved: #148491 Approved by: https://github.com/jamesjwu ghstack dependencies: #148381
1 parent c16cd25 commit 666508e

File tree

3 files changed

+10
-14
lines changed

3 files changed

+10
-14
lines changed

test/dynamo/test_aot_autograd_cache.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -300,26 +300,25 @@ def fn(x, y):
300300

301301
@inductor_config.patch("fx_graph_remote_cache", False)
302302
@inductor_config.patch("fx_graph_cache", True)
303-
@functorch_config.patch({"enable_autograd_cache": True})
303+
@functorch_config.patch(
304+
{"enable_autograd_cache": True, "strict_autograd_cache": True}
305+
)
304306
@dynamo_config.patch("compiled_autograd", True)
305307
def test_compiled_autograd_bypass(self):
308+
# Need to make the compiled autograd graph serializable
306309
def fn(a, b):
307310
out = a.cos() + b
308311
loss = out.sum()
309312
ga, gb = torch.autograd.grad(loss, inputs=[a, b])
310313

311314
a = torch.randn(25, requires_grad=True)
312315
b = torch.randn(25, requires_grad=True)
313-
a2 = a.detach().clone().requires_grad_(True)
314-
b2 = b.detach().clone().requires_grad_(True)
315316
compiled_fn = torch.compile(fn, backend="inductor")
316-
self.assertEqual(fn(a, b), compiled_fn(a2, b2))
317-
self.assertEqual(
318-
counters["aot_autograd"]["autograd_cache_miss"], 1
319-
) # from compiled forward
320-
self.assertEqual(
321-
counters["aot_autograd"]["autograd_cache_bypass"], 1
322-
) # from compiled autograd
317+
with self.assertRaisesRegex(
318+
torch._dynamo.exc.BackendCompilerFailed,
319+
"BypassAOTAutogradCache: Unsupported call_function target torch._dynamo.compiled_autograd.ops.validate_outputs",
320+
):
321+
compiled_fn(a, b)
323322

324323
@inductor_config.patch("fx_graph_remote_cache", False)
325324
@inductor_config.patch("fx_graph_cache", True)

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
207207
Checks that the graph module only uses supported operators
208208
"""
209209
nodes = gm.graph.nodes
210-
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
211-
raise BypassAOTAutogradCache(
212-
"Cannot cache a graph with compiled autograd enabled"
213-
)
214210
if torch._inductor.config.freezing:
215211
raise BypassAOTAutogradCache("Cannot cache a graph with freezing enabled")
216212

torch/csrc/autograd/python_function.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace torch::autograd {
2424

2525
// A Function which is implemented by a Python object (i.e., a THPFunction).
2626
// Calls to 'apply' are forwarded to the Python method implementation.
27+
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
2728
struct PyNode : public Node {
2829
PyNode(THPObjectPtr obj) : obj(obj.release()) {}
2930

0 commit comments

Comments (0)