74
74
raise e
75
75
76
76
77
- class _Unassigned :
78
- pass
79
-
80
-
81
- _UNASSIGNED = _Unassigned ()
82
-
83
77
DimList = list
84
78
85
79
pytree = torch .utils ._pytree
@@ -1118,7 +1112,7 @@ class _DispatchCacheEntryOutputInfo:
1118
1112
1119
1113
@dataclass_slots
1120
1114
@dataclass (frozen = True )
1121
- class _DispatchCacheEntry :
1115
+ class _DispatchCacheValidEntry :
1122
1116
"""
1123
1117
Entry type for the FakeTensor dispatch cache. It supports two types of outputs
1124
1118
1) tensor
@@ -1131,6 +1125,20 @@ class _DispatchCacheEntry:
1131
1125
is_output_tuple : bool = False
1132
1126
1133
1127
1128
+ @dataclass_slots
1129
+ @dataclass (frozen = True )
1130
+ class _DispatchCacheBypassEntry :
1131
+ """
1132
+ Entry type for a negative cache entry.
1133
+ """
1134
+
1135
+ reason : str
1136
+
1137
+
1138
+ if TYPE_CHECKING :
1139
+ _DispatchCacheEntry = Union [_DispatchCacheValidEntry , _DispatchCacheBypassEntry ]
1140
+
1141
+
1134
1142
@dataclass_slots
1135
1143
@dataclass (frozen = True )
1136
1144
class _BypassDispatchCache (Exception ):
@@ -1418,37 +1426,64 @@ def _cached_dispatch_impl(
1418
1426
Lookup a cache entry for the given arguments. If none exists, dispatch
1419
1427
and cache the result (if the result is eligible for caching).
1420
1428
"""
1421
- output : object = _UNASSIGNED
1422
1429
try :
1423
1430
state = _CacheKeyState (self .shape_env )
1424
1431
key = self ._cache_key (state , func , args , kwargs )
1425
- if state .cache_on_shape_env ():
1426
- assert state .shape_env is not None
1427
- cache = state .shape_env .fake_tensor_cache
1428
- else :
1429
- cache = FakeTensorMode .cache
1430
- entry = cache .get (key , None )
1431
- if entry is not None :
1432
- output = self ._output_from_cache_entry (state , entry , key , func , args )
1433
- FakeTensorMode .cache_hits += 1
1434
- if self .cache_crosscheck_enabled :
1435
- # For debugging / testing: Validate that the output synthesized
1436
- # from the cache matches the output created by normal dispatch.
1437
- with disable_fake_tensor_cache (self ):
1438
- self ._crosscheck_cache_output (output , func , types , args , kwargs )
1439
- else :
1440
- self ._validate_cache_key (func , args , kwargs )
1441
- output = self ._dispatch_impl (func , types , args , kwargs )
1442
- entry = self ._make_cache_entry (state , key , func , args , kwargs , output )
1443
- key .strip_shape_env ()
1444
- cache [key ] = entry
1445
- FakeTensorMode .cache_misses += 1
1446
1432
except _BypassDispatchCache as e :
1433
+ # We couldn't create the cache key at all
1434
+ FakeTensorMode .cache_bypasses [e .reason ] += 1
1435
+ return self ._dispatch_impl (func , types , args , kwargs )
1436
+
1437
+ if state .cache_on_shape_env ():
1438
+ assert state .shape_env is not None
1439
+ cache = state .shape_env .fake_tensor_cache
1440
+ set_cache_key = _set_cache_key_for_shape_env
1441
+ else :
1442
+ cache = FakeTensorMode .cache
1443
+ set_cache_key = _set_cache_key
1444
+ entry = cache .get (key , None )
1445
+
1446
+ if entry is not None :
1447
+ if isinstance (entry , _DispatchCacheBypassEntry ):
1448
+ # This represents a negative cache entry - we already saw that the
1449
+ # output is uncachable. Compute it from first principals.
1450
+ FakeTensorMode .cache_bypasses [entry .reason ] += 1
1451
+ return self ._dispatch_impl (func , types , args , kwargs )
1452
+
1453
+ # We have a cache entry.
1454
+ output = self ._output_from_cache_entry (state , entry , key , func , args )
1455
+ FakeTensorMode .cache_hits += 1
1456
+ if self .cache_crosscheck_enabled :
1457
+ # For debugging / testing: Validate that the output synthesized
1458
+ # from the cache matches the output created by normal dispatch.
1459
+ with disable_fake_tensor_cache (self ):
1460
+ self ._crosscheck_cache_output (output , func , types , args , kwargs )
1461
+ return output
1462
+
1463
+ # We don't have a cache entry.
1464
+ output = self ._dispatch_impl (func , types , args , kwargs )
1465
+
1466
+ try :
1467
+ self ._validate_cache_key (func , args , kwargs )
1468
+ except _BypassDispatchCache as e :
1469
+ # We ran "extra" checks on the cache key and determined that it's no
1470
+ # good. Record the reason and mark it so we don't bother validating
1471
+ # again.
1447
1472
FakeTensorMode .cache_bypasses [e .reason ] += 1
1473
+ set_cache_key (cache , key , _DispatchCacheBypassEntry (e .reason ))
1474
+ return output
1448
1475
1449
- if output is _UNASSIGNED :
1450
- output = self ._dispatch_impl (func , types , args , kwargs )
1476
+ try :
1477
+ entry = self ._make_cache_entry (state , key , func , args , kwargs , output )
1478
+ except _BypassDispatchCache as e :
1479
+ # We had trouble making the cache entry. Record the reason and mark
1480
+ # it.
1481
+ FakeTensorMode .cache_bypasses [e .reason ] += 1
1482
+ set_cache_key (cache , key , _DispatchCacheBypassEntry (e .reason ))
1483
+ return output
1451
1484
1485
+ set_cache_key (cache , key , entry )
1486
+ FakeTensorMode .cache_misses += 1
1452
1487
return output
1453
1488
1454
1489
def _cache_key (
@@ -1634,17 +1669,15 @@ def _validate_output_for_cache_entry(
1634
1669
kwargs : Mapping [str , object ],
1635
1670
output : Optional [FakeTensor ],
1636
1671
) -> None :
1637
- from torch .fx .experimental .symbolic_shapes import has_free_unbacked_symbols
1638
-
1672
+ # Is this even possible?
1639
1673
if isinstance (output , (int , type (None ))):
1640
1674
return
1641
1675
1642
- if isinstance (output , torch .SymInt ):
1643
- if has_free_unbacked_symbols (output ):
1644
- # This is unreachable but adding the check for extra safety in
1645
- # case we change code path in future.
1646
- raise _BypassDispatchCache ("unbacked symbol in output" )
1647
- return
1676
+ if _has_unrepresented_symbols (state , output ):
1677
+ # Unbacked symbols are fine - but only if they're also represented
1678
+ # in the input. If there are any new unbacked symbols then we can't
1679
+ # cache this output.
1680
+ raise _BypassDispatchCache ("unrepresented symbol in output" )
1648
1681
1649
1682
# Some ops return tuples of Tensors, but it's rare, so avoid
1650
1683
# the complexity of caching other types.
@@ -1718,7 +1751,7 @@ def _get_output_info_for_cache_entry(
1718
1751
# we can synthesize a tensor here and do the checks on that instance.
1719
1752
# This approach keeps the (more frequent) cache-hit path as lightweight
1720
1753
# as possible.
1721
- entry_for_synth_output = _DispatchCacheEntry (
1754
+ entry_for_synth_output = _DispatchCacheValidEntry (
1722
1755
output_infos = (entry ,), is_output_tuple = False
1723
1756
)
1724
1757
synth_output = self ._output_from_cache_entry (
@@ -1742,7 +1775,7 @@ def _make_cache_entry(
1742
1775
args : Sequence [object ],
1743
1776
kwargs : Mapping [str , object ],
1744
1777
output : Optional [FakeTensor ],
1745
- ) -> _DispatchCacheEntry :
1778
+ ) -> _DispatchCacheValidEntry :
1746
1779
"""
1747
1780
Make a cache entry object for the given 'output' Tensor. Raises
1748
1781
_BypassDispatchCache if the output tensor has characteristics that
@@ -1773,7 +1806,7 @@ def _make_cache_entry(
1773
1806
output_info = _DispatchCacheEntryOutputInfo (
1774
1807
inplace_idx = None , metadata = None , view_idx = None , constant_value = output
1775
1808
)
1776
- return _DispatchCacheEntry (
1809
+ return _DispatchCacheValidEntry (
1777
1810
output_infos = (output_info ,), is_output_tuple = False
1778
1811
)
1779
1812
@@ -1794,15 +1827,15 @@ def _make_cache_entry(
1794
1827
)
1795
1828
for out_elem in output
1796
1829
]
1797
- return _DispatchCacheEntry (
1830
+ return _DispatchCacheValidEntry (
1798
1831
output_infos = tuple (output_infos ), is_output_tuple = True
1799
1832
)
1800
1833
1801
1834
else :
1802
1835
output_info = self ._get_output_info_for_cache_entry (
1803
1836
state , key , func , args , kwargs , output
1804
1837
)
1805
- return _DispatchCacheEntry (
1838
+ return _DispatchCacheValidEntry (
1806
1839
output_infos = (output_info ,), is_output_tuple = False
1807
1840
)
1808
1841
@@ -1882,7 +1915,7 @@ def check_value(
1882
1915
def _output_from_cache_entry (
1883
1916
self ,
1884
1917
state : _CacheKeyState ,
1885
- entry : _DispatchCacheEntry ,
1918
+ entry : _DispatchCacheValidEntry ,
1886
1919
key : _DispatchCacheKey ,
1887
1920
func : OpOverload ,
1888
1921
args : Sequence [object ],
@@ -2886,6 +2919,19 @@ def from_tensor(
2886
2919
_StoragePointer = object
2887
2920
2888
2921
2922
+ def _has_unrepresented_symbols (
2923
+ state : _CacheKeyState , output : Optional [FakeTensor ]
2924
+ ) -> bool :
2925
+ from torch .fx .experimental .symbolic_shapes import _iterate_exprs
2926
+
2927
+ for s in _iterate_exprs (output ):
2928
+ for symbol in s .free_symbols :
2929
+ if symbol not in state .known_symbols :
2930
+ return True
2931
+
2932
+ return False
2933
+
2934
+
2889
2935
# NB: returns fake tensors
2890
2936
def run_fallback_kernel (
2891
2937
fake_mode : FakeTensorMode ,
@@ -2951,6 +2997,23 @@ def map_out(e: T) -> Union[T, FakeTensor]:
2951
2997
return pytree .tree_map (map_out , r )
2952
2998
2953
2999
3000
+ def _set_cache_key_for_shape_env (
3001
+ cache : dict [_DispatchCacheKey , _DispatchCacheEntry ],
3002
+ key : _DispatchCacheKey ,
3003
+ entry : _DispatchCacheEntry ,
3004
+ ) -> None :
3005
+ key .strip_shape_env ()
3006
+ cache [key ] = entry
3007
+
3008
+
3009
+ def _set_cache_key (
3010
+ cache : dict [_DispatchCacheKey , _DispatchCacheEntry ],
3011
+ key : _DispatchCacheKey ,
3012
+ entry : _DispatchCacheEntry ,
3013
+ ) -> None :
3014
+ cache [key ] = entry
3015
+
3016
+
2954
3017
# Just for use to allow copying a module to fake tensors,
2955
3018
# does not apply elsewhere
2956
3019
class FakeCopyMode (TorchFunctionMode ):
0 commit comments