@@ -1164,7 +1164,7 @@ class DispatchCacheInfo:
 
 
 class FakeTensorMode(TorchDispatchMode):
-    cache: dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
+    cache: dict[_DispatchCacheKey, Optional[_DispatchCacheEntry]] = {}
     cache_hits: int = 0
    cache_misses: int = 0
     cache_bypasses: dict[str, int] = defaultdict(int)
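The widened value type is what makes negative caching possible: storing `None` under a key records that the key's output is known-uncacheable, which is distinct from the key being absent. A minimal sketch of that three-state pattern; the names `_MISS`, `lookup`, and the cacheability test are illustrative, not part of the PR:

```python
from typing import Callable, Optional

_MISS = object()  # sentinel distinct from None: "key never seen"

cache: dict[str, Optional[int]] = {}

def lookup(key: str, compute: Callable[[], int], cacheable: bool) -> int:
    entry = cache.get(key, _MISS)
    if entry is None:
        # Negative entry: we already learned this key's result can't be cached.
        return compute()
    if entry is not _MISS:
        return entry  # genuine hit
    result = compute()
    # Store the result, or None as a negative entry so the cacheability
    # question is never re-examined for this key.
    cache[key] = result if cacheable else None
    return result
```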
@@ -1418,33 +1418,65 @@ def _cached_dispatch_impl(
         Lookup a cache entry for the given arguments. If none exists, dispatch
         and cache the result (if the result is eligible for caching).
         """
-        output: object = _UNASSIGNED
         try:
             state = _CacheKeyState(self.shape_env)
             key = self._cache_key(state, func, args, kwargs)
-            if state.cache_on_shape_env():
-                assert state.shape_env is not None
-                cache = state.shape_env.fake_tensor_cache
-            else:
-                cache = FakeTensorMode.cache
-            entry = cache.get(key, None)
-            if entry is not None:
-                output = self._output_from_cache_entry(state, entry, key, func, args)
-                FakeTensorMode.cache_hits += 1
-                if self.cache_crosscheck_enabled:
-                    # For debugging / testing: Validate that the output synthesized
-                    # from the cache matches the output created by normal dispatch.
-                    with disable_fake_tensor_cache(self):
-                        self._crosscheck_cache_output(output, func, types, args, kwargs)
-            else:
-                self._validate_cache_key(func, args, kwargs)
-                output = self._dispatch_impl(func, types, args, kwargs)
-                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
-                key.strip_shape_env()
-                cache[key] = entry
-                FakeTensorMode.cache_misses += 1
         except _BypassDispatchCache as e:
+            # We couldn't create the cache key at all.
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        if state.cache_on_shape_env():
+            assert state.shape_env is not None
+            cache = state.shape_env.fake_tensor_cache
+            set_cache_key = _set_cache_key_for_shape_env
+        else:
+            cache = FakeTensorMode.cache
+            set_cache_key = _set_cache_key
+        entry = cache.get(key, _UNASSIGNED)
+
+        if entry is None:
+            # This represents a negative cache entry - we already saw that the
+            # output is uncacheable. Compute it from first principles.
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        if entry is not _UNASSIGNED:
+            # We have a cache entry.
+            if TYPE_CHECKING:
+                assert isinstance(entry, _DispatchCacheEntry)
+            output = self._output_from_cache_entry(state, entry, key, func, args)
+            FakeTensorMode.cache_hits += 1
+            if self.cache_crosscheck_enabled:
+                # For debugging / testing: Validate that the output synthesized
+                # from the cache matches the output created by normal dispatch.
+                with disable_fake_tensor_cache(self):
+                    self._crosscheck_cache_output(output, func, types, args, kwargs)
+            return output
+
+        # We don't have a cache entry.
+        try:
+            self._validate_cache_key(func, args, kwargs)
+        except _BypassDispatchCache as e:
+            # We ran "extra" checks on the cache key and determined that it's no
+            # good. Record the reason and mark it so we don't bother validating
+            # again.
+            FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, None)
+            return self._dispatch_impl(func, types, args, kwargs)
+
+        output = self._dispatch_impl(func, types, args, kwargs)
+        try:
+            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
+        except _BypassDispatchCache as e:
+            # We had trouble making the cache entry. Record the reason and mark
+            # it - but note that at this point we DO have a valid output already
+            # so no reason to recompute it.
             FakeTensorMode.cache_bypasses[e.reason] += 1
+            set_cache_key(cache, key, None)
+            return output
+
+        set_cache_key(cache, key, entry)
+        FakeTensorMode.cache_misses += 1
 
         if output is _UNASSIGNED:
             output = self._dispatch_impl(func, types, args, kwargs)
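The single broad `try` around dispatch is split into narrow stages, so each `_BypassDispatchCache` site can react appropriately: a key-construction failure bails out before touching the cache, a validation failure poisons the key with `None` before computing, and an entry-construction failure poisons the key but keeps the already-computed output. A condensed sketch of that staging; `Bypass`, `cached_call`, and the callables it takes are hypothetical stand-ins, not the PR's API:

```python
class Bypass(Exception):
    """Hypothetical stand-in for _BypassDispatchCache."""

def cached_call(cache: dict, build_key, validate_key, make_entry, dispatch):
    try:
        key = build_key()
    except Bypass:
        return dispatch()  # no usable key: can't consult the cache at all

    try:
        validate_key(key)
    except Bypass:
        cache[key] = None  # poison the key, then compute normally
        return dispatch()

    output = dispatch()  # from here on we hold a valid output
    try:
        cache[key] = make_entry(output)
    except Bypass:
        cache[key] = None  # keep the output, just never cache this key
    return output
```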
@@ -1634,17 +1666,15 @@ def _validate_output_for_cache_entry(
         kwargs: Mapping[str, object],
         output: Optional[FakeTensor],
     ) -> None:
-        from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
-
+        # Is this even possible?
         if isinstance(output, (int, type(None))):
             return
 
-        if isinstance(output, torch.SymInt):
-            if has_free_unbacked_symbols(output):
-                # This is unreachable but adding the check for extra safety in
-                # case we change code path in future.
-                raise _BypassDispatchCache("unbacked symbol in output")
-            return
+        if _has_unrepresented_unbacked_symbols(state, output):
+            # Unbacked symbols are fine - but only if they're also represented
+            # in the input. If there are any new unbacked symbols then we can't
+            # cache this output.
+            raise _BypassDispatchCache("unbacked symbol in output")
 
         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
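The replaced `torch.SymInt` special case generalizes to: an output may mention unbacked symbols only if those symbols were already seen while hashing the inputs, since only then does the cache key pin them down. A toy sympy-level illustration of the "unrepresented symbol" idea, using `free_symbols` rather than the PR's typed traversal; `has_unrepresented` and the `known` set are illustrative only:

```python
import sympy

def has_unrepresented(output_expr: sympy.Expr, known: set[sympy.Symbol]) -> bool:
    # Cacheable only if every symbol in the output was seen in the inputs.
    return any(s not in known for s in output_expr.free_symbols)

u0, u1 = sympy.symbols("u0 u1")
known = {u0}                              # u0 was hashed into the cache key
print(has_unrepresented(u0 + 1, known))   # False: safe to cache
print(has_unrepresented(u0 + u1, known))  # True: u1 is brand new
```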
@@ -2886,6 +2916,25 @@ def from_tensor(
 _StoragePointer = object
 
 
+def _has_unrepresented_unbacked_symbols(
+    state: _CacheKeyState, output: Optional[FakeTensor]
+) -> bool:
+    from sympy.core.traversal import iterargs
+
+    from torch.fx.experimental.symbolic_shapes import _iterate_exprs
+    from torch.utils._sympy.symbol import symbol_is_type, SymT
+
+    for s in _iterate_exprs(output):
+        for arg in iterargs(s):
+            if arg.is_Symbol and symbol_is_type(
+                arg, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)
+            ):
+                if arg not in state.known_symbols:
+                    return True
+
+    return False
+
+
 # NB: returns fake tensors
 def run_fallback_kernel(
     fake_mode: FakeTensorMode,
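For intuition on the traversal above: `iterargs` walks an expression and all of its subexpressions, so the symbol check reaches symbols at any nesting depth. A standalone snippet exercising only the sympy side (no PyTorch symbol typing); it assumes a sympy recent enough to provide `sympy.core.traversal`:

```python
import sympy
from sympy.core.traversal import iterargs

u0, s0 = sympy.symbols("u0 s0")
expr = sympy.Max(u0 * 2, s0 + 1)
# iterargs yields expr, then its args, then their args, and so on,
# so Symbols are found no matter how deeply they're nested.
syms = [a for a in iterargs(expr) if a.is_Symbol]
print(syms)  # [u0, s0] (order may vary)
```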
@@ -2951,6 +3000,23 @@ def map_out(e: T) -> Union[T, FakeTensor]:
     return pytree.tree_map(map_out, r)
 
 
+def _set_cache_key_for_shape_env(
+    cache: dict[_DispatchCacheKey, Optional[_DispatchCacheEntry]],
+    key: _DispatchCacheKey,
+    entry: Optional[_DispatchCacheEntry],
+) -> None:
+    key.strip_shape_env()
+    cache[key] = entry
+
+
+def _set_cache_key(
+    cache: dict[_DispatchCacheKey, Optional[_DispatchCacheEntry]],
+    key: _DispatchCacheKey,
+    entry: Optional[_DispatchCacheEntry],
+) -> None:
+    cache[key] = entry
+
+
 # Just for use to allow copying a module to fake tensors,
 # does not apply elsewhere
 class FakeCopyMode(TorchFunctionMode):
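Two setters exist so that every insertion site applies the same key normalization: keys bound for a `ShapeEnv`-local cache first drop their `ShapeEnv` reference, presumably because the owning env is implied by which cache the key lives in. A minimal sketch of the pattern with a hypothetical `CacheKey` (a stand-in for `_DispatchCacheKey`) whose hash ignores the stripped field, so mutating the key before insert is safe:

```python
from typing import Optional

class CacheKey:
    """Hypothetical stand-in for _DispatchCacheKey."""
    def __init__(self, payload: tuple, env: object = None) -> None:
        self.payload = payload
        self.env = env
    def strip_env(self) -> None:
        # Drop the back-reference; hashing/equality never looked at it.
        self.env = None
    def __hash__(self) -> int:
        return hash(self.payload)
    def __eq__(self, other: object) -> bool:
        return isinstance(other, CacheKey) and self.payload == other.payload

def set_for_env(cache: dict, key: CacheKey, entry: Optional[object]) -> None:
    key.strip_env()  # normalize before storing in the env-local cache
    cache[key] = entry

def set_global(cache: dict, key: CacheKey, entry: Optional[object]) -> None:
    cache[key] = entry  # global cache keeps the key as-is
```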