8000 FakeTensorMode dispatch shouldn't include bypass in exception context… · pytorch/pytorch@aa84c03 · GitHub
[go: up one dir, main page]

Skip to content

Commit aa84c03

Browse files
aorenstepytorchmergebot
authored andcommitted
FakeTensorMode dispatch shouldn't include bypass in exception context (#153780)
In the FakeTensor cache when we get a bypass exception while computing the cache key (call this exc_1) we need to dispatch to the original operation. It's possible for the dispatch to the original operation to get its own exception which we want to bubble up to the caller (call this exc_2). If we directly dispatch from within the handler for exc_1 then exc_2 will have a `__context__` of exc_1 - which can cause deviations between cached and non-cached behavior - so we need to be a bit careful when we call the dispatch. Testing: test_aotdispatch.py::TestAOTExport::test_aot_export_predispatch_outdtype fails before this change and passes after. Pull Request resolved: #153780 Approved by: https://github.com/oulgen
1 parent 6803419 commit aa84c03

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torch/_subclasses/fake_tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,14 +1426,22 @@ def _cached_dispatch_impl(
14261426
Lookup a cache entry for the given arguments. If none exists, dispatch
14271427
and cache the result (if the result is eligible for caching).
14281428
"""
1429+
state = None
1430+
key = None
14291431
try:
14301432
state = _CacheKeyState(self.shape_env)
14311433
key = self._cache_key(state, func, args, kwargs)
14321434
except _BypassDispatchCache as e:
14331435
# We couldn't create the cache key at all
14341436
FakeTensorMode.cache_bypasses[e.reason] += 1
1437+
1438+
if key is None:
1439+
# Do this dispatch outside the above except handler so if it
1440+
# generates its own exception there won't be a __context__ caused by
1441+
# the caching mechanism.
14351442
return self._dispatch_impl(func, types, args, kwargs)
14361443

1444+
assert state is not None
14371445
if state.cache_on_shape_env():
14381446
assert state.shape_env is not None
14391447
cache = state.shape_env.fake_tensor_cache

0 commit comments

Comments
 (0)
0