Introduce unsafe way to mark functions as cacheable (#151603) · pytorch/pytorch@0f8613b · GitHub


Commit 0f8613b

oulgen authored and pytorchmergebot committed
Introduce unsafe way to mark functions as cacheable (#151603)
Pull Request resolved: #151603
Approved by: https://github.com/jamesjwu
ghstack dependencies: #151768, #151609
1 parent 67c2869 commit 0f8613b

5 files changed (+94 -6 lines)

test/dynamo/test_aot_autograd_cache.py

Lines changed: 75 additions & 0 deletions
@@ -309,6 +309,81 @@ def fn(a):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
 
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch(
+        {"enable_autograd_cache": True, "strict_autograd_cache": True}
+    )
+    @parametrize("fn_select", ("tag_activation_checkpoint", "allow_in_graph"))
+    def test_unsafe_mark_cacheable(self, fn_select):
+        if fn_select == "tag_activation_checkpoint":
+            from torch.utils.checkpoint import checkpoint
+
+            def gn(x, y, z=None):
+                a = torch.matmul(x, y)
+                if z is not None:
+                    return torch.matmul(a, z)
+                return a
+
+            @torch.compile
+            def fn(x, y, z):
+                return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
+
+            fn_name = "torch.ops.higher_order.tag_activation_checkpoint"
+        else:
+            assert fn_select == "allow_in_graph"
+
+            @torch._dynamo.allow_in_graph
+            class AllowInGraphFunc(torch.autograd.Function):
+                @staticmethod
+                def forward(_, x):
+                    torch._dynamo.graph_break()
+                    return x.sin()
+
+            @torch.compile
+            def fn(x, y, z):
+                return AllowInGraphFunc.apply(x)
+
+            fn_name = "torch._dynamo.variables.misc.trampoline_autograd_apply"
+
+        x = torch.randn(4, 4)
+        y = torch.randn(4, 4)
+        z = torch.randn(4, 4)
+        args = (x, y, z)
+
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.BackendCompilerFailed,
+            r".*BypassAOTAutogradCache: Unsupported call_function target .*",
+        ):
+            fn(*args)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+        self._clear_dynamo_and_codecache()
+
+        if fn_select == "allow_in_graph":
+            # TODO: Fix allow in graph
+            raise unittest.SkipTest(
+                "Allow in graph produces an unserializable cache artifact"
+            )
+
+        with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
+            fn(*args)
+
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+            self._clear_dynamo_and_codecache()
+
+            fn(*args)
+
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", False)
     @functorch_config.patch({"enable_autograd_cache": True})
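
Outside the test harness, the behavior exercised above reduces to patching the new inductor config around compilation. The sketch below is a minimal, hedged example: it assumes default cache settings, drops the z keyword handling from the test, and reuses the activation-checkpoint op name verbatim.

import torch
from torch._inductor import config as inductor_config
from torch.utils.checkpoint import checkpoint


def gn(x, y):
    return torch.matmul(x, y)


@torch.compile
def fn(x, y):
    return torch.cos(checkpoint(gn, x, y, use_reentrant=False))


# While the op name is on the allowlist, AOTAutogradCache treats the checkpoint
# higher-order op as cacheable instead of bypassing (when the autograd cache is
# enabled).
with inductor_config.patch(
    "unsafe_marked_cacheable_functions",
    ["torch.ops.higher_order.tag_activation_checkpoint"],
):
    fn(torch.randn(4, 4), torch.randn(4, 4))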

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 2 additions & 0 deletions
@@ -157,6 +157,7 @@ def is_safe_torch_function(target):
     return (
         function_name in torch_non_c_binding_in_graph_functions
         or function_name in SAFE_TORCH_FUNCTIONS
+        or function_name in torch._inductor.config.unsafe_marked_cacheable_functions
     )
 
 def is_torch_function(target):
@@ -824,6 +825,7 @@ def load(
         except Exception as e:
             cache_key = None
             counters["aot_autograd"]["autograd_cache_bypass"] += 1
+            log.info("Bypassing autograd cache due to: %s", e)
             cache_state = "bypass"
             cache_event_time = time.time_ns()
             cache_info["cache_bypass_reason"] = str(e)
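
For intuition, the new branch is a plain membership test keyed on fully qualified names. The helper below is hypothetical (it is not part of this commit); it builds the "<module>.<name>" string the same way the rewritten cacheable() in torch/_ops.py does further down.

from torch._inductor import config as inductor_config


def unsafely_marked_cacheable(fn) -> bool:
    # Hypothetical helper: build the "<module>.<name>" string and test it
    # against the allowlist, mirroring the membership check added above.
    name = f"{fn.__module__}.{fn.__name__}"
    return name in inductor_config.unsafe_marked_cacheable_functions


def my_plain_fn(x):
    return x + 1


with inductor_config.patch(
    "unsafe_marked_cacheable_functions",
    [f"{my_plain_fn.__module__}.{my_plain_fn.__name__}"],
):
    print(unsafely_marked_cacheable(my_plain_fn))  # True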

torch/_functorch/autograd_function.py

Lines changed: 0 additions & 3 deletions
@@ -773,8 +773,5 @@ def backward(ctx, *grad):
 
         return ApplyTemplate.apply(*new_fwd_args)
 
-    def cacheable(self):
-        return torch._functorch.config.autograd_cache_allow_custom_autograd_functions
-
 
 autograd_function_apply = AutogradFunctionApply()
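
The override removed here is not lost: the same functorch flag is now consulted inside the rewritten cacheable() added to torch/_ops.py below. A minimal sketch of exercising the flag through that relocated check, assuming the config patch helper behaves as it does in the test above:

import torch._inductor.config  # ensure the inductor config consulted by cacheable() is loaded
from torch._functorch import config as functorch_config
from torch._functorch.autograd_function import autograd_function_apply

# With the flag on, the isinstance(self, AutogradFunctionApply) branch of the
# rewritten cacheable() in torch/_ops.py below reports this op as cacheable.
with functorch_config.patch("autograd_cache_allow_custom_autograd_functions", True):
    assert autograd_function_apply.cacheable()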

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions
@@ -124,6 +124,9 @@ def prologue_fusion_enabled() -> bool:
 # Unsafe way to skip dynamic shape guards to get faster cache load
 unsafe_skip_cache_dynamic_shape_guards: bool = False
 
+# Unsafe way to mark function as cacheable
+unsafe_marked_cacheable_functions: list[str] = []
+
 # sleep in inductor for testing
 sleep_sec_TESTING_ONLY: Optional[int] = None
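
The new option holds a list of fully qualified function names. As a sketch, the same config-patch decorator form already used for the neighboring cache flags in the test file applies here as well; the wrapper function below is hypothetical, and the op name is the one from the test.

import torch
from torch._inductor import config as inductor_config


@inductor_config.patch(
    "unsafe_marked_cacheable_functions",
    ["torch.ops.higher_order.tag_activation_checkpoint"],
)
def compile_and_run(fn, *args):
    # Any compilation triggered inside this call sees the listed op as
    # cacheable; the previous config value is restored on return.
    return torch.compile(fn)(*args)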

torch/_ops.py

Lines changed: 14 additions & 3 deletions
@@ -6,7 +6,7 @@
 import inspect
 import sys
 import types
-from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
+from typing import Any, Callable, final, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import Concatenate, ParamSpec
 
 import torch
@@ -329,8 +329,19 @@ def maybe_run_autograd(*args: _P.args, **kwargs: _P.kwargs) -> _T:
     def namespace(self):
         return self._ns
 
-    def cacheable(self):
-        return self._cacheable
+    @final
+    def cacheable(self) -> bool:
+        from torch._functorch.autograd_function import AutogradFunctionApply
+
+        return (
+            self._cacheable
+            or f"{self.__module__}.{self.__name__}"
+            in torch._inductor.config.unsafe_marked_cacheable_functions
+            or (
+                isinstance(self, AutogradFunctionApply)
+                and torch._functorch.config.autograd_cache_allow_custom_autograd_functions
+            )
+        )
 
     def fallthrough(self, dispatch_key):
         self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
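
The effect of the rewritten cacheable() is easy to observe on a concrete higher-order op. A short sketch, assuming the checkpoint op can be imported from torch._higher_order_ops.wrap; the allowlist string is the one used in the test above.

from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch._inductor import config as inductor_config

with inductor_config.patch(
    "unsafe_marked_cacheable_functions",
    ["torch.ops.higher_order.tag_activation_checkpoint"],
):
    # f"{op.__module__}.{op.__name__}" matches the allowlisted string, so the
    # second disjunct of the rewritten cacheable() returns True.
    assert tag_activation_checkpoint.cacheable()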

0 commit comments