Introduce unsafe way to mark functions as cacheable by oulgen · Pull Request #151603 · pytorch/pytorch · GitHub

Introduce unsafe way to mark functions as cacheable #151603


Closed · oulgen wants to merge 5 commits
75 changes: 75 additions & 0 deletions test/dynamo/test_aot_autograd_cache.py
@@ -309,6 +309,81 @@ def fn(a):
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch(
        {"enable_autograd_cache": True, "strict_autograd_cache": True}
    )
    @parametrize("fn_select", ("tag_activation_checkpoint", "allow_in_graph"))
    def test_unsafe_mark_cacheable(self, fn_select):
        if fn_select == "tag_activation_checkpoint":
            from torch.utils.checkpoint import checkpoint

            def gn(x, y, z=None):
                a = torch.matmul(x, y)
                if z is not None:
                    return torch.matmul(a, z)
                return a

            @torch.compile
            def fn(x, y, z):
                return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))

            fn_name = "torch.ops.higher_order.tag_activation_checkpoint"
        else:
            assert fn_select == "allow_in_graph"

            @torch._dynamo.allow_in_graph
            class AllowInGraphFunc(torch.autograd.Function):
                @staticmethod
                def forward(_, x):
                    torch._dynamo.graph_break()
                    return x.sin()

            @torch.compile
            def fn(x, y, z):
                return AllowInGraphFunc.apply(x)

            fn_name = "torch._dynamo.variables.misc.trampoline_autograd_apply"

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = torch.randn(4, 4)
        args = (x, y, z)

        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            r".*BypassAOTAutogradCache: Unsupported call_function target .*",
        ):
            fn(*args)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)

        self._clear_dynamo_and_codecache()

        if fn_select == "allow_in_graph":
            # TODO: Fix allow in graph
            raise unittest.SkipTest(
                "Allow in graph produces an unserializable cache artifact"
            )

        with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
            fn(*args)

            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)

            self._clear_dynamo_and_codecache()

            fn(*args)

            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
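For readers who want the user-facing picture rather than the test harness, here is a minimal standalone sketch of the workflow the test above exercises. It is not part of the diff; the helper names (`gn`, `fn`) are illustrative, and whether a run actually populates or hits the cache depends on your cache configuration (e.g. `fx_graph_cache`, `enable_autograd_cache`).

```python
# Standalone sketch of the scenario exercised by test_unsafe_mark_cacheable:
# activation checkpointing lowers to the tag_activation_checkpoint higher-order
# op, which AOTAutogradCache normally refuses to cache. Marking it cacheable
# via the new config lets the compiled graph be cached, at the user's own risk.
import torch
from torch.utils.checkpoint import checkpoint

def gn(x, y):
    return torch.matmul(x, y).sin()

@torch.compile
def fn(x, y):
    return torch.cos(checkpoint(gn, x, y, use_reentrant=False))

x, y = torch.randn(4, 4), torch.randn(4, 4)

with torch._inductor.config.patch(
    "unsafe_marked_cacheable_functions",
    ["torch.ops.higher_order.tag_activation_checkpoint"],
):
    fn(x, y)  # compiles and caches instead of bypassing the AOTAutograd cache
```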
2 changes: 2 additions & 0 deletions torch/_functorch/_aot_autograd/autograd_cache.py
@@ -157,6 +157,7 @@ def is_safe_torch_function(target):
    return (
        function_name in torch_non_c_binding_in_graph_functions
        or function_name in SAFE_TORCH_FUNCTIONS
        or function_name in torch._inductor.config.unsafe_marked_cacheable_functions
    )

def is_torch_function(target):
@@ -824,6 +825,7 @@ def load(
    except Exception as e:
        cache_key = None
        counters["aot_autograd"]["autograd_cache_bypass"] += 1
        log.info("Bypassing autograd cache due to: %s", e)
        cache_state = "bypass"
        cache_event_time = time.time_ns()
        cache_info["cache_bypass_reason"] = str(e)
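To make the allowlist check above concrete: entries in `unsafe_marked_cacheable_functions` are matched against a target's fully qualified name, the same `"<module>.<name>"` string that the `torch/_ops.py` change below builds. The following is a simplified sketch of that resolution, not the actual cache code, which works on FX node targets.

```python
# Rough sketch (simplified from is_safe_torch_function above): a target is
# treated as cache-safe if "<module>.<name>" appears in the allowlist, e.g.
# "torch.ops.higher_order.tag_activation_checkpoint".
import torch

def is_marked_cacheable(target) -> bool:
    qualified_name = f"{target.__module__}.{target.__name__}"
    return qualified_name in torch._inductor.config.unsafe_marked_cacheable_functions
```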
3 changes: 0 additions & 3 deletions torch/_functorch/autograd_function.py
@@ -773,8 +773,5 @@ def backward(ctx, *grad):

        return ApplyTemplate.apply(*new_fwd_args)

    def cacheable(self):
        return torch._functorch.config.autograd_cache_allow_custom_autograd_functions


autograd_function_apply = AutogradFunctionApply()
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
@@ -124,6 +124,9 @@ def prologue_fusion_enabled() -> bool:
# Unsafe way to skip dynamic shape guards to get faster cache load
unsafe_skip_cache_dynamic_shape_guards: bool = False

# Unsafe way to mark function as cacheable
unsafe_marked_cacheable_functions: list[str] = []

# sleep in inductor for testing
sleep_sec_TESTING_ONLY: Optional[int] = None

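The knob added above is a plain list of fully qualified function names, so besides the `config.patch(...)` context manager used in the test, it can also be set globally. A minimal sketch follows; the name listed is the one from the test, and the setting should be treated as unsafe, exactly as the code comment says.

```python
# Global opt-in sketch: every name listed here is treated as cache-safe by
# AOTAutogradCache with no correctness checking. Use sparingly.
import torch._inductor.config as inductor_config

inductor_config.unsafe_marked_cacheable_functions = [
    "torch.ops.higher_order.tag_activation_checkpoint",
]
```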
17 changes: 14 additions & 3 deletions torch/_ops.py
@@ -6,7 +6,7 @@
import inspect
import sys
import types
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, final, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Concatenate, ParamSpec

import torch
@@ -329,8 +329,19 @@ def maybe_run_autograd(*args: _P.args, **kwargs: _P.kwargs) -> _T:
    def namespace(self):
        return self._ns

    def cacheable(self):
        return self._cacheable

    @final
    def cacheable(self) -> bool:
        from torch._functorch.autograd_function import AutogradFunctionApply

        return (
            self._cacheable
            or f"{self.__module__}.{self.__name__}"
            in torch._inductor.config.unsafe_marked_cacheable_functions
            or (
                isinstance(self, AutogradFunctionApply)
                and torch._functorch.config.autograd_cache_allow_custom_autograd_functions
            )
        )

    def fallthrough(self, dispatch_key):
        self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)

Review comment from @bdhirsh (Contributor), Apr 18, 2025, on the added unsafe_marked_cacheable_functions check:

Just a thought - I'm a bit worried about a situation like:

(1) the end user is running torch.compile

(2) their compiled code uses some 3rd-party library code that (incorrectly) marks a bad function as cache-safe

(3) they get silent correctness issues and report them to us.

One obvious debugging step would be to ask them to find any of these functions that have been marked cache-safe and turn off caching for them, so we can easily tell if it's our fault (a general caching bug) or someone else's fault (a 3rd-party lib doing unsafe things).

This will be a pain to do in the current setup: even if we tell the user to set unsafe_marked_cacheable_functions=[], the user doesn't have an easy way of ensuring that they update this config last, after any 3rd-party libs add to it.

Given this setup, what do you think of adding (yet another) config to "ignore" these markings?
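There is no "ignore these markings" config in this PR; as a stop-gap for the debugging scenario @bdhirsh describes, a user could force the allowlist empty around the suspect run. The context manager below is a hypothetical user-side helper, not a PyTorch API.

```python
# Hypothetical debugging helper, not part of this PR: temporarily drop all
# unsafe cacheable markings (including ones added by third-party libraries)
# so a suspected silent-correctness issue can be re-run with those functions
# falling back to a cache bypass. Clear the relevant caches and recompile for
# the comparison to be meaningful.
import contextlib

import torch._inductor.config as inductor_config

@contextlib.contextmanager
def ignore_unsafe_cacheable_markings():
    saved = list(inductor_config.unsafe_marked_cacheable_functions)
    inductor_config.unsafe_marked_cacheable_functions = []
    try:
        yield
    finally:
        inductor_config.unsafe_marked_cacheable_functions = saved
```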