WIP: Fix caching when output has unbacked · pytorch/pytorch@ae66b2f

Commit ae66b2f

WIP: Fix caching when output has unbacked
ghstack-source-id: d47bdf5
Pull Request resolved: #153034
1 parent a28dcdb commit ae66b2f
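
For context (an illustrative note, not part of the commit): "output has unbacked" refers to an op whose inputs carry no symbolic shapes but whose result size is a fresh, data-dependent SymInt. A minimal sketch of that situation, using torch.nonzero as the data-dependent op; it assumes a recent torch.export, and the symbol name u0 is only illustrative:

import torch

class Nonzero(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The number of output rows depends on the *data*, not on x's (static)
        # shape, so the traced output size is an unbacked SymInt.
        return torch.nonzero(x)

ep = torch.export.export(Nonzero(), (torch.tensor([0, 1, 0, 1]),), strict=False)
print(ep)  # the nonzero result is printed with a size like [u0, 1]; u0 is unbacked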

File tree: 3 files changed, +138 -48 lines

test/test_fake_tensor.py

Lines changed: 26 additions & 3 deletions
@@ -2265,13 +2265,10 @@ def count_invoke_subgraph_keys():
         gc.collect()
         self.assertTrue(count_invoke_subgraph_keys() == 0)

-
-
     @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
     def test_invoke_subgraph_cacheable_inplace(self):
         invoke_subgraph = torch._higher_order_ops.invoke_subgraph

-
         def fn(x, y):
             # aten ops are used so that eager backend graph is suitable for fake
             # tensor testing
@@ -2317,5 +2314,31 @@ def fn(x, y):
             extract_tensor_metadata(b),
         )

+    def test_unbacked_output(self):
+        # The point of this test is to have an op which has no symbols as input
+        # but a symbol as an output and make sure that we skip caching it.
+        class LengthsGather(torch.nn.Module):
+            def forward(
+                self,
+                input: torch.Tensor,
+                lengths: torch.Tensor,
+                indices: torch.Tensor,
+                offsets: torch.Tensor,
+            ) -> torch.Tensor:
+                bias = torch.gather(offsets, 0, indices)
+                lengths_selected = torch.gather(lengths, 0, indices)
+                index = torch.repeat_interleave(bias, lengths_selected, dim=0)
+                return index
+
+        input = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+        lengths = torch.tensor([0, 2, 3, 1, 4])
+        indices = torch.tensor([2, 3, 4, 6, 7, 8, 9])
+        offsets = torch.cumsum(lengths, 0)
+        ep = torch.export.export(LengthsGather(), (input, lengths, indices, offsets), strict=False)
+
+        FakeTensorMode.cache_clear()
+        ep.run_decompositions({})
+        self.assertBypasses("unrepresented symbol in output", 2)
+
 if __name__ == "__main__":
     run_tests()
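
Side note (an illustrative sketch, not from this commit): the counters the new test relies on are FakeTensorMode.cache_hits, cache_misses, and the per-reason cache_bypasses dict that _cached_dispatch_impl increments; assertBypasses is presumably a small helper defined elsewhere in this test file that reads cache_bypasses. They can also be inspected directly, though the exact counts depend on whether the fake-tensor dispatch cache is enabled in your configuration:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

x = torch.randn(4)
y = torch.randn(4)

FakeTensorMode.cache_clear()
mode = FakeTensorMode()
fx, fy = mode.from_tensor(x), mode.from_tensor(y)
with mode:
    fx + fy  # first call: expected to populate the cache (miss)
    fx + fy  # identical second call: expected to be served from the cache (hit)

print(FakeTensorMode.cache_hits, FakeTensorMode.cache_misses)
print(dict(FakeTensorMode.cache_bypasses))  # reason -> count, e.g. "unrepresented symbol in output"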

torch/_subclasses/_fake_tensor_utils.py

Lines changed: 4 additions & 0 deletions
@@ -217,6 +217,7 @@ class _CacheKeyState:
     # We track the SymNodes so when we get the output we can see if it exactly
     # matches one of the inputs so we can uncache it properly.
     sym_node_lookup: dict[int, int]  # id(SymNode) -> index
+    known_symbols: set[sympy.Symbol]

     # There are cases where we're asked to perform an op when we have no
     # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
@@ -226,6 +227,7 @@ class _CacheKeyState:

     def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
         self.sym_node_lookup = {}
+        self.known_symbols = set()
         self.shape_env = shape_env

     def cache_on_shape_env(self) -> bool:
@@ -247,6 +249,8 @@ def convert_sym_int(self, result: list[object], arg: SymInt) -> None:
             result.append(_InputBackref(self.sym_node_lookup[node_id]))
         else:
             self.sym_node_lookup[node_id] = len(result)
+            for symbol in arg.node.expr.free_symbols:
+                self.known_symbols.add(symbol)
             if self.shape_env is None:
                 self.shape_env = arg.node.shape_env
             result.append(_PySymInputStub(arg))
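
An illustrative sketch (plain sympy, independent of the classes above) of the bookkeeping convert_sym_int now performs: every free symbol of an input SymInt's expression is recorded in known_symbols, so the output-side check can tell symbols that are already part of the cache key apart from fresh ones:

import sympy

known_symbols: set[sympy.Symbol] = set()

def record_input_expr(expr: sympy.Expr) -> None:
    # Mirrors: for symbol in arg.node.expr.free_symbols: self.known_symbols.add(symbol)
    for symbol in expr.free_symbols:
        known_symbols.add(symbol)

s0, s1 = sympy.symbols("s0 s1")
record_input_expr(s0 + 1)
record_input_expr(s0 * s1)
print(known_symbols)  # {s0, s1}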

torch/_subclasses/fake_tensor.py

Lines changed: 108 additions & 45 deletions
@@ -74,12 +74,6 @@
         raise e


-class _Unassigned:
-    pass
-
-
-_UNASSIGNED = _Unassigned()
-
 DimList = list

 pytree = torch.utils._pytree
@@ -1118,7 +1112,7 @@ class _DispatchCacheEntryOutputInfo:

 @dataclass_slots
 @dataclass(frozen=True)
-class _DispatchCacheEntry:
+class _DispatchCacheValidEntry:
     """
     Entry type for the FakeTensor dispatch cache. It supports two types of outputs
     1) tensor
@@ -1131,6 +1125,20 @@ class _DispatchCacheEntry:
     is_output_tuple: bool = False


+@dataclass_slots
+@dataclass(frozen=True)
+class _DispatchCacheBypassEntry:
+    """
+    Entry type for a negative cache entry.
+    """
+
+    reason: str
+
+
+if TYPE_CHECKING:
+    _DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry]
+
+
 @dataclass_slots
 @dataclass(frozen=True)
 class _BypassDispatchCache(Exception):
@@ -1418,37 +1426,64 @@ def _cached_dispatch_impl(
         Lookup a cache entry for the given arguments. If none exists, dispatch
         and cache the result (if the result is eligible for caching).
         """
-        output: object = _UNASSIGNED
         try:
             state = _CacheKeyState(self.shape_env)
             key = self._cache_key(state, func, args, kwargs)
-            if state.cache_on_shape_env():
-                assert state.shape_env is not None
-                cache = state.shape_env.fake_tensor_cache
-            else:
-                cache = FakeTensorMode.cache
-            entry = cache.get(key, None)
-            if entry is not None:
-                output = self._output_from_cache_entry(state, entry, key, func, args)
-                FakeTensorMode.cache_hits += 1
-                if self.cache_crosscheck_enabled:
-                    # For debugging / testing: Validate that the output synthesized
-                    # from the cache matches the output created by normal dispatch.
-                    with disable_fake_tensor_cache(self):
-                        self._crosscheck_cache_output(output, func, types, args, kwargs)
-            else:
-                self._validate_cache_key(func, args, kwargs)
-                output = self._dispatch_impl(func, types, args, kwargs)
-                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
-                key.strip_shape_env()
-                cache[key] = entry
-                FakeTensorMode.cache_misses += 1
         except _BypassDispatchCache as e:
+            # We couldn't create the cache key at all
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        if state.cache_on_shape_env():
+            assert state.shape_env is not None
+            cache = state.shape_env.fake_tensor_cache
+            set_cache_key = _set_cache_key_for_shape_env
+        else:
+            cache = FakeTensorMode.cache
+            set_cache_key = _set_cache_key
+        entry = cache.get(key, None)
+
+        if entry is not None:
+            if isinstance(entry, _DispatchCacheBypassEntry):
+                # This represents a negative cache entry - we already saw that the
+                # output is uncachable. Compute it from first principles.
+                FakeTensorMode.cache_bypasses[entry.reason] += 1
+                return self._dispatch_impl(func, types, args, kwargs)
+
+            # We have a cache entry.
+            output = self._output_from_cache_entry(state, entry, key, func, args)
+            FakeTensorMode.cache_hits += 1
+            if self.cache_crosscheck_enabled:
+                # For debugging / testing: Validate that the output synthesized
+                # from the cache matches the output created by normal dispatch.
+                with disable_fake_tensor_cache(self):
+                    self._crosscheck_cache_output(output, func, types, args, kwargs)
+            return output
+
+        # We don't have a cache entry.
+        output = self._dispatch_impl(func, types, args, kwargs)
+
+        try:
+            self._validate_cache_key(func, args, kwargs)
+        except _BypassDispatchCache as e:
+            # We ran "extra" checks on the cache key and determined that it's no
+            # good. Record the reason and mark it so we don't bother validating
+            # again.
             FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
+            return output

-        if output is _UNASSIGNED:
-            output = self._dispatch_impl(func, types, args, kwargs)
+        try:
+            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
+        except _BypassDispatchCache as e:
+            # We had trouble making the cache entry. Record the reason and mark
+            # it.
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
+            return output

+        set_cache_key(cache, key, entry)
+        FakeTensorMode.cache_misses += 1
         return output

     def _cache_key(
@@ -1634,17 +1669,15 @@ def _validate_output_for_cache_entry(
         kwargs: Mapping[str, object],
         output: Optional[FakeTensor],
     ) -> None:
-        from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
-
+        # Is this even possible?
         if isinstance(output, (int, type(None))):
             return

-        if isinstance(output, torch.SymInt):
-            if has_free_unbacked_symbols(output):
-                # This is unreachable but adding the check for extra safety in
-                # case we change code path in future.
-                raise _BypassDispatchCache("unbacked symbol in output")
-            return
+        if _has_unrepresented_symbols(state, output):
+            # Unbacked symbols are fine - but only if they're also represented
+            # in the input. If there are any new unbacked symbols then we can't
+            # cache this output.
+            raise _BypassDispatchCache("unrepresented symbol in output")

         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
@@ -1718,7 +1751,7 @@ def _get_output_info_for_cache_entry(
         # we can synthesize a tensor here and do the checks on that instance.
         # This approach keeps the (more frequent) cache-hit path as lightweight
         # as possible.
-        entry_for_synth_output = _DispatchCacheEntry(
+        entry_for_synth_output = _DispatchCacheValidEntry(
            output_infos=(entry,), is_output_tuple=False
        )
        synth_output = self._output_from_cache_entry(
@@ -1742,7 +1775,7 @@ def _make_cache_entry(
         args: Sequence[object],
         kwargs: Mapping[str, object],
         output: Optional[FakeTensor],
-    ) -> _DispatchCacheEntry:
+    ) -> _DispatchCacheValidEntry:
         """
         Make a cache entry object for the given 'output' Tensor. Raises
         _BypassDispatchCache if the output tensor has characteristics that
@@ -1773,7 +1806,7 @@ def _make_cache_entry(
             output_info = _DispatchCacheEntryOutputInfo(
                 inplace_idx=None, metadata=None, view_idx=None, constant_value=output
             )
-            return _DispatchCacheEntry(
+            return _DispatchCacheValidEntry(
                 output_infos=(output_info,), is_output_tuple=False
             )

@@ -1794,15 +1827,15 @@ def _make_cache_entry(
                )
                for out_elem in output
            ]
-            return _DispatchCacheEntry(
+            return _DispatchCacheValidEntry(
                 output_infos=tuple(output_infos), is_output_tuple=True
             )

         else:
             output_info = self._get_output_info_for_cache_entry(
                 state, key, func, args, kwargs, output
             )
-            return _DispatchCacheEntry(
+            return _DispatchCacheValidEntry(
                 output_infos=(output_info,), is_output_tuple=False
             )

@@ -1882,7 +1915,7 @@ def check_value(
     def _output_from_cache_entry(
         self,
         state: _CacheKeyState,
-        entry: _DispatchCacheEntry,
+        entry: _DispatchCacheValidEntry,
         key: _DispatchCacheKey,
         func: OpOverload,
         args: Sequence[object],
@@ -2886,6 +2919,19 @@ def from_tensor(
 _StoragePointer = object


+def _has_unrepresented_symbols(
+    state: _CacheKeyState, output: Optional[FakeTensor]
+) -> bool:
+    from torch.fx.experimental.symbolic_shapes import _iterate_exprs
+
+    for s in _iterate_exprs(output):
+        for symbol in s.free_symbols:
+            if symbol not in state.known_symbols:
+                return True
+
+    return False
+
+
 # NB: returns fake tensors
 def run_fallback_kernel(
     fake_mode: FakeTensorMode,
@@ -2951,6 +2997,23 @@ def map_out(e: T) -> Union[T, FakeTensor]:
     return pytree.tree_map(map_out, r)


+def _set_cache_key_for_shape_env(
+    cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
+    key: _DispatchCacheKey,
+    entry: _DispatchCacheEntry,
+) -> None:
+    key.strip_shape_env()
+    cache[key] = entry
+
+
+def _set_cache_key(
+    cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
+    key: _DispatchCacheKey,
+    entry: _DispatchCacheEntry,
+) -> None:
+    cache[key] = entry
+
+
 # Just for use to allow copying a module to fake tensors,
 # does not apply elsewhere
 class FakeCopyMode(TorchFunctionMode):
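
To summarize the two mechanisms this file now combines, namely skipping the cache when the output mentions a symbol never seen among the inputs, and remembering that decision as a negative ("bypass") entry so the checks are not repeated on every call, here is a condensed, self-contained sketch in plain Python/sympy. The names are illustrative only, not the real implementation:

import sympy

class BypassEntry:
    """Stands in for _DispatchCacheBypassEntry: a negative cache entry."""
    def __init__(self, reason: str) -> None:
        self.reason = reason

cache: dict = {}

def cached_eval(key, input_exprs, compute):
    entry = cache.get(key)
    if isinstance(entry, BypassEntry):
        return compute()                                   # known-uncacheable: recompute every time
    if entry is not None:
        return entry                                       # valid cached result
    out = compute()
    known = {sym for e in input_exprs for sym in e.free_symbols}
    out_syms = out.free_symbols if isinstance(out, sympy.Expr) else set()
    if out_syms - known:                                   # fresh (e.g. unbacked) symbol in the output
        cache[key] = BypassEntry("unrepresented symbol in output")
    else:
        cache[key] = out
    return out

s0, u0 = sympy.symbols("s0 u0")
print(cached_eval("mul", [s0], lambda: s0 * 2))            # cached normally
print(cached_eval("nonzero-ish", [s0], lambda: u0 + 1))    # bypassed: u0 never appeared in inputs
print(isinstance(cache["nonzero-ish"], BypassEntry))       # True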

0 commit comments

Comments
 (0)
0