WIP: Fix caching when output has unbacked · pytorch/pytorch@e59e26a · GitHub

Commit e59e26a

WIP: Fix caching when output has unbacked
ghstack-source-id: 4500c0a
Pull Request resolved: #153034
1 parent a28dcdb commit e59e26a

3 files changed: +101 -32 lines

torch/_subclasses/_fake_tensor_utils.py
Lines changed: 3 additions & 0 deletions

@@ -217,6 +217,7 @@ class _CacheKeyState:
     # We track the SymNodes so when we get the output we can see if it exactly
     # matches one of the inputs so we can uncache it properly.
     sym_node_lookup: dict[int, int]  # id(SymNode) -> index
+    known_symbols: set[sympy.Basic]

     # There are cases where we're asked to perform an op when we have no
     # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
@@ -226,6 +227,7 @@ class _CacheKeyState:

     def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
         self.sym_node_lookup = {}
+        self.known_symbols = set()
         self.shape_env = shape_env

     def cache_on_shape_env(self) -> bool:
@@ -247,6 +249,7 @@ def convert_sym_int(self, result: list[object], arg: SymInt) -> None:
             result.append(_InputBackref(self.sym_node_lookup[node_id]))
         else:
             self.sym_node_lookup[node_id] = len(result)
+            self.known_symbols.add(arg.node.expr)
             if self.shape_env is None:
                 self.shape_env = arg.node.shape_env
             result.append(_PySymInputStub(arg))
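The new known_symbols field records, while the cache key is built, the sympy expression behind every SymInt input so the output can later be checked against it. Below is a minimal, self-contained sketch of that idea using plain sympy and hypothetical helper names (not the real _CacheKeyState API), simplified to free symbols rather than filtering to unbacked ones:

import sympy

# Hypothetical stand-in for _CacheKeyState.known_symbols.
known_symbols: set[sympy.Basic] = set()

def record_input(expr: sympy.Expr) -> None:
    # Analogue of convert_sym_int: remember the symbols each symbolic input contributes.
    known_symbols.update(expr.free_symbols)

def output_introduces_new_symbols(expr: sympy.Expr) -> bool:
    # Analogue of the output check: a free symbol never seen in the inputs
    # means the output cannot be reconstructed from the cache key alone.
    return bool(expr.free_symbols - known_symbols)

u0, u1 = sympy.symbols("u0 u1", integer=True)
record_input(u0 + 1)
assert not output_introduces_new_symbols(2 * u0)  # only input symbols: cacheable
assert output_introduces_new_symbols(u0 + u1)     # fresh u1: must bypass the cache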

torch/_subclasses/fake_tensor.py
Lines changed: 97 additions & 31 deletions

@@ -1164,7 +1164,7 @@ class DispatchCacheInfo:


 class FakeTensorMode(TorchDispatchMode):
-    cache: dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
+    cache: dict[_DispatchCacheKey, Optional[_DispatchCacheEntry]] = {}
     cache_hits: int = 0
     cache_misses: int = 0
     cache_bypasses: dict[str, int] = defaultdict(int)
@@ -1418,33 +1418,65 @@ def _cached_dispatch_impl(
         Lookup a cache entry for the given arguments. If none exists, dispatch
         and cache the result (if the result is eligible for caching).
         """
-        output: object = _UNASSIGNED
         try:
             state = _CacheKeyState(self.shape_env)
             key = self._cache_key(state, func, args, kwargs)
-            if state.cache_on_shape_env():
-                assert state.shape_env is not None
-                cache = state.shape_env.fake_tensor_cache
-            else:
-                cache = FakeTensorMode.cache
-            entry = cache.get(key, None)
-            if entry is not None:
-                output = self._output_from_cache_entry(state, entry, key, func, args)
-                FakeTensorMode.cache_hits += 1
-                if self.cache_crosscheck_enabled:
-                    # For debugging / testing: Validate that the output synthesized
-                    # from the cache matches the output created by normal dispatch.
-                    with disable_fake_tensor_cache(self):
-                        self._crosscheck_cache_output(output, func, types, args, kwargs)
-            else:
-                self._validate_cache_key(func, args, kwargs)
-                output = self._dispatch_impl(func, types, args, kwargs)
-                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
-                key.strip_shape_env()
-                cache[key] = entry
-                FakeTensorMode.cache_misses += 1
         except _BypassDispatchCache as e:
+            # We couldn't create the cache key at all
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        if state.cache_on_shape_env():
+            assert state.shape_env is not None
+            cache = state.shape_env.fake_tensor_cache
+            set_cache_key = _set_cache_key_for_shape_env
+        else:
+            cache = FakeTensorMode.cache
+            set_cache_key = _set_cache_key
+        entry = cache.get(key, _UNASSIGNED)
+
+        if entry is None:
+            # This represents a negative cache entry - we already saw that the
+            # output is uncachable. Compute it from first principles.
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        if entry is not _UNASSIGNED:
+            # We have a cache entry.
+            if TYPE_CHECKING:
+                assert isinstance(entry, _DispatchCacheEntry)
+            output = self._output_from_cache_entry(state, entry, key, func, args)
+            FakeTensorMode.cache_hits += 1
+            if self.cache_crosscheck_enabled:
+                # For debugging / testing: Validate that the output synthesized
+                # from the cache matches the output created by normal dispatch.
+                with disable_fake_tensor_cache(self):
+                    self._crosscheck_cache_output(output, func, types, args, kwargs)
+            return output
+
+        # We don't have a cache entry.
+        try:
+            self._validate_cache_key(func, args, kwargs)
+        except _BypassDispatchCache as e:
+            # We ran "extra" checks on the cache key and determined that it's no
+            # good. Record the reason and mark it so we don't bother validating
+            # again.
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, None)
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        output = self._dispatch_impl(func, types, args, kwargs)
+        try:
+            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
+        except _BypassDispatchCache as e:
+            # We had trouble making the cache entry. Record the reason and mark
+            # it - but note that at this point we DO have a valid output already
+            # so no reason to recompute it.
             FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, None)
+            return output
+
+        set_cache_key(cache, key, entry)
+        FakeTensorMode.cache_misses += 1

         if output is _UNASSIGNED:
             output = self._dispatch_impl(func, types, args, kwargs)
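The rewritten lookup distinguishes three outcomes: the key was never seen (the _UNASSIGNED sentinel), the key was seen but its output is known to be uncachable (a negative entry stored as None), or a real entry exists. A condensed, standalone sketch of that pattern with hypothetical names, not the FakeTensorMode API:

from typing import Any, Callable, Optional

_UNSET = object()  # stand-in for the _UNASSIGNED sentinel

class Uncachable(Exception):
    """Stand-in for _BypassDispatchCache raised while building a cache entry."""

def cached_call(
    cache: dict[Any, Optional[Any]],
    key: Any,
    compute: Callable[[], Any],
    make_entry: Callable[[Any], Any],
    output_from_entry: Callable[[Any], Any],
) -> Any:
    entry = cache.get(key, _UNSET)
    if entry is None:
        # Negative entry: we already learned this key's output can't be cached.
        return compute()
    if entry is not _UNSET:
        # Positive entry: rebuild the output from the cached description.
        return output_from_entry(entry)
    output = compute()
    try:
        cache[key] = make_entry(output)   # may decide the output is uncachable
    except Uncachable:
        cache[key] = None                 # record the negative result, keep the output
    return output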
@@ -1634,17 +1666,15 @@ def _validate_output_for_cache_entry(
         kwargs: Mapping[str, object],
         output: Optional[FakeTensor],
     ) -> None:
-        from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
-
+        # Is this even possible?
         if isinstance(output, (int, type(None))):
             return

-        if isinstance(output, torch.SymInt):
-            if has_free_unbacked_symbols(output):
-                # This is unreachable but adding the check for extra safety in
-                # case we change code path in future.
-                raise _BypassDispatchCache("unbacked symbol in output")
-            return
+        if _has_unrepresented_unbacked_symbols(state, output):
+            # Unbacked symbols are fine - but only if they're also represented
+            # in the input. If there are any new unbacked symbols then we can't
+            # cache this output.
+            raise _BypassDispatchCache("unbacked symbol in output")

         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
@@ -2886,6 +2916,25 @@ def from_tensor(
 _StoragePointer = object


+def _has_unrepresented_unbacked_symbols(
+    state: _CacheKeyState, output: Optional[FakeTensor]
+) -> bool:
+    from sympy.core.traversal import iterargs
+
+    from torch.fx.experimental.symbolic_shapes import _iterate_exprs
+    from torch.utils._sympy.symbol import symbol_is_type, SymT
+
+    for s in _iterate_exprs(output):
+        for arg in iterargs(s):
+            if arg.is_Symbol and symbol_is_type(
+                arg, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)
+            ):
+                if arg not in state.known_symbols:
+                    return True
+
+    return False
+
+
 # NB: returns fake tensors
 def run_fallback_kernel(
     fake_mode: FakeTensorMode,
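As a hedged usage note on the helper above: symbol_is_type classifies a sympy symbol by its name prefix, and unbacked SymInts conventionally surface as u0, u1, ... A small sketch under that naming assumption (the tuple form of the call mirrors the one used in the diff):

import sympy
from torch.utils._sympy.symbol import symbol_is_type, SymT

u0 = sympy.Symbol("u0", integer=True)  # unbacked-int style name
s0 = sympy.Symbol("s0", integer=True)  # backed size-style name

# Only the u-prefixed symbol counts as unbacked; the size symbol does not.
assert symbol_is_type(u0, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))
assert not symbol_is_type(s0, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))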
@@ -2951,6 +3000,23 @@ def map_out(e: T) -> Union[T, FakeTensor]:
     return pytree.tree_map(map_out, r)


+def _set_cache_key_for_shape_env(
+    cache: dict[_DispatchCacheKey, Optional[_DispatchCacheEntry]],
+    key: _DispatchCacheKey,
+    entry: Optional[_DispatchCacheEntry],
+) -> None:
+    key.strip_shape_env()
+    cache[key] = entry
+
+
+def _set_cache_key(
+    cache: dict[_DispatchCacheKey, Optional[_DispatchCacheEntry]],
+    key: _DispatchCacheKey,
+    entry: Optional[_DispatchCacheEntry],
+) -> None:
+    cache[key] = entry
+
+
 # Just for use to allow copying a module to fake tensors,
 # does not apply elsewhere
 class FakeCopyMode(TorchFunctionMode):

torch/fx/experimental/symbolic_shapes.py
Lines changed: 1 addition & 1 deletion

@@ -3307,7 +3307,7 @@ def __init__(
         # with the GC.
         self.fake_tensor_cache: dict[
             torch._subclasses.fake_tensor._DispatchCacheKey,
-            torch._subclasses.fake_tensor._DispatchCacheEntry,
+            Optional[torch._subclasses.fake_tensor._DispatchCacheEntry],
         ] = {}

     @contextlib.contextmanager