@@ -1,5 +1,3 @@
- from __future__ import annotations
-
import base64
import copyreg
import dataclasses
@@ -25,6 +23,8 @@
import threading
import warnings
from bisect import bisect_right
+ from collections.abc import Generator, KeysView, Sequence
+ from concurrent.futures import Future
from copy import copy
from ctypes import c_void_p, CDLL, cdll
from datetime import timedelta
@@ -112,25 +112,22 @@
    )
else:

-     def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+     def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:
        pass

-     def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+     def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
        pass

-     def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+     def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
        pass

-     def use_global_cache() -> bool:  # type: ignore[misc]
+     def use_global_cache() -> bool:
        return False


T = TypeVar("T")

if TYPE_CHECKING:
-     from collections.abc import Generator, KeysView, Sequence
-     from concurrent.futures import Future
-
    from .compile_fx import _CompileFxKwargs
    from .graph import GraphLowering
    from .ir import ChoiceCaller
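Note on the pattern above: with `from __future__ import annotations` removed, annotations are evaluated eagerly, so names that are imported only under `if TYPE_CHECKING:` have to appear as string literals in signatures, which is what the hunks below do. A minimal sketch of the idea, using a hypothetical `parse_price` helper rather than anything from this file:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from decimal import Decimal  # imported for the type checker only


def parse_price(raw: str) -> Optional["Decimal"]:
    # The quoted annotation is never evaluated at runtime, so the
    # TYPE_CHECKING-only import above is sufficient for mypy.
    from decimal import Decimal  # the real import happens at call time

    return Decimal(raw) if raw else None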
@@ -267,11 +264,11 @@ def get_global_cache(self) -> dict[str, Any]:

    def lookup(
        self,
-         choices: list[ChoiceCaller],
+         choices: list["ChoiceCaller"],
        op: str,
        inputs: str,
-         benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
-     ) -> dict[ChoiceCaller, float]:
+         benchmark: Optional[Callable[[Any], dict["ChoiceCaller", float]]],
+     ) -> dict["ChoiceCaller", float]:
        """
        Check to see if we have benchmarked the given choice callers. For each
        choice caller:
@@ -617,7 +614,7 @@ def get_hash(self, obj: Any) -> str:
        serialized_data = self.dumps(obj)
        return sha256_hash(serialized_data)

-     def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
+     def debug_lines(self, inp: "FxGraphHashDetails") -> list[str]:
        """
        Get a printable string describing in more detail all the attributes
        comprising an object. Useful for debugging when one graph hashes
@@ -652,7 +649,7 @@ def get_str(obj: Any) -> str:


def build_code_hash(
-     roots: list[str] | None, prefix: str, hasher: hashlib._Hash
+     roots: Optional[list[str]], prefix: str, hasher: "hashlib._Hash"
) -> None:
    for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
        spec = lib.module_finder.find_spec(lib.name, None)
@@ -759,8 +756,8 @@ class FxGraphHashDetails:
    def __init__(
        self,
        gm: torch.fx.GraphModule,
-         example_inputs: Sequence[InputType],
-         fx_kwargs: _CompileFxKwargs,
+         example_inputs: Sequence["InputType"],
+         fx_kwargs: "_CompileFxKwargs",
        inputs_to_check: Sequence[int],
    ) -> None:
        self.gm = gm
@@ -877,8 +874,8 @@ def _get_custom_pass_detail(

def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,
-     example_inputs: Sequence[InputType],
-     fx_kwargs: _CompileFxKwargs,
+     example_inputs: Sequence["InputType"],
+     fx_kwargs: "_CompileFxKwargs",
    inputs_to_check: Sequence[int],
) -> tuple[str, list[str]]:
    """
@@ -931,14 +928,14 @@ class GuardedCache(Generic[T]):
    """

    @classmethod
-     def _get_tmp_dir_for_key(cls: type[GuardedCache[T]], _key: str) -> str:
+     def _get_tmp_dir_for_key(cls: type["GuardedCache[T]"], _key: str) -> str:
        raise NotImplementedError("Implement _get_tmp_dir_for_key on parent class")

    @classmethod
    def iterate_over_candidates(
-         cls: type[GuardedCache[T]],
+         cls: type["GuardedCache[T]"],
        local: bool,
-         remote_cache: Optional[RemoteCache[JsonDataTy]],
+         remote_cache: Optional["RemoteCache[JsonDataTy]"],
        key: str,
    ) -> Generator[tuple[T, bytes], None, None]:
        if local:
@@ -970,10 +967,10 @@ def iterate_over_candidates(

    @classmethod
    def find_guarded_entry(
-         cls: type[GuardedCache[T]],
+         cls: type["GuardedCache[T]"],
        key: str,
        local: bool,
-         remote_cache: Optional[RemoteCache[JsonDataTy]],
+         remote_cache: Optional["RemoteCache[JsonDataTy]"],
        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
        hints: list[int],
    ) -> tuple[Optional[T], Optional[bytes], dict[str, str]]:
@@ -1031,7 +1028,7 @@ def find_guarded_entry(

    @classmethod
    def _filter_backed_symints(
-         cls: type[GuardedCache[T]], inputs: Sequence[InputType]
+         cls: type["GuardedCache[T]"], inputs: Sequence["InputType"]
    ) -> list[torch.SymInt]:
        """
        Get the backed SymInt objects from the input list. Note that we can never
@@ -1040,7 +1037,7 @@ def _filter_backed_symints(
        return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)]

    @classmethod
-     def _get_shape_env(cls: type[GuardedCache[T]]) -> Optional[ShapeEnv]:
+     def _get_shape_env(cls: type["GuardedCache[T]"]) -> Optional[ShapeEnv]:
        """
        Helper to get the shape env from the tracing context.
        """
@@ -1088,7 +1085,7 @@ def _get_tmp_dir() -> str:
        return os.path.join(cache_dir(), "fxgraph")

    @classmethod
-     def _get_tmp_dir_for_key(cls: type[FxGraphCache], key: str) -> str:
+     def _get_tmp_dir_for_key(cls: type["FxGraphCache"], key: str) -> str:
        """
        Return the disk location for a given cache key.
        """
@@ -1098,7 +1095,7 @@ def _get_tmp_dir_for_key(cls: type[FxGraphCache], key: str) -> str:
    def cache_hit_post_compile(
        graph: CompiledFxGraph,
        cache_info: dict[str, Any],
-         constants: CompiledFxGraphConstants,
+         constants: "CompiledFxGraphConstants",
    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
        """
        Cache specific post compile steps that need to run if we find a graph in the cache
@@ -1173,10 +1170,10 @@ def cache_hit_post_compile(
    @staticmethod
    def _lookup_graph(
        key: str,
-         example_inputs: Sequence[InputType],
+         example_inputs: Sequence["InputType"],
        local: bool,
-         remote_cache: Optional[RemoteCache[JsonDataTy]],
-         constants: CompiledFxGraphConstants,
+         remote_cache: Optional["RemoteCache[JsonDataTy]"],
+         constants: "CompiledFxGraphConstants",
        evaluate_guards: Optional[
            Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
        ] = None,
@@ -1245,10 +1242,10 @@ def _write_to_local_cache(key: str, content: bytes) -> None:
    @staticmethod
    def _save_graph(
        key: str,
-         compiled_graph: OutputCode,
-         example_inputs: Sequence[InputType],
+         compiled_graph: "OutputCode",
+         example_inputs: Sequence["InputType"],
        local: bool,
-         remote_cache: Optional[RemoteCache[JsonDataTy]],
+         remote_cache: Optional["RemoteCache[JsonDataTy]"],
    ) -> None:
        """
        Store a serialized CompiledFxGraph on disk.
@@ -1361,8 +1358,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
    @staticmethod
    def prepare_key(
        gm: torch.fx.GraphModule,
-         example_inputs: Sequence[InputType],
-         fx_kwargs: _CompileFxKwargs,
+         example_inputs: Sequence["InputType"],
+         fx_kwargs: "_CompileFxKwargs",
        inputs_to_check: Sequence[int],
        remote: bool,
    ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
@@ -1396,7 +1393,7 @@ def prepare_key(
        return (key, debug_lines), {}

    @staticmethod
-     def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
+     def get_remote_cache() -> Optional["RemoteCache[JsonDataTy]"]:
        """
        Attempts to load the remote cache, returns None on error.
        """
@@ -1412,15 +1409,15 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
    def load_with_key(
        key: str,
        debug_lines: list[str],
-         example_inputs: Sequence[InputType],
+         example_inputs: Sequence["InputType"],
        local: bool,
-         remote_cache: Optional[RemoteCache[JsonDataTy]],
+         remote_cache: Optional["RemoteCache[JsonDataTy]"],
        is_backward: bool,
-         constants: CompiledFxGraphConstants,
+         constants: "CompiledFxGraphConstants",
        evaluate_guards: Optional[
            Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
        ] = None,
-     ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+     ) -> tuple[Optional["CompiledFxGraph"], dict[str, Any]]:
        """
        Lookup the graph with the given key, and return results and metadata.
        Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@@ -1535,7 +1532,7 @@ class AotCodeCompiler:
    @classmethod
    def compile(
        cls,
-         graph: GraphLowering,
+         graph: "GraphLowering",
        wrapper_code: str,
        kernel_code: str,
        serialized_extern_kernel_nodes: Optional[str],
@@ -2252,7 +2249,7 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
            raise

    @classmethod
-     def _get_uncompiled_header(cls, device: str) -> str | None:
+     def _get_uncompiled_header(cls, device: str) -> Optional[str]:
        """
        Given a device type, returns the path to a CPP header file to be precompiled.
        Currently, this is only utilized by the cpp_wrapper classes.
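The return annotation switches from `str | None` to `Optional[str]`, presumably because the PEP 604 union syntax is only evaluable at runtime on Python 3.10+, whereas `typing.Optional` also works on older interpreters once annotations are no longer lazy. A small sketch with a hypothetical `locate_header` helper (not the method above):

from typing import Optional


def locate_header(name: str) -> Optional[str]:
    # Same meaning as `str | None` for type checkers, but the annotation
    # can still be evaluated as an expression on Python 3.9.
    return name if name.endswith(".h") else None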
@@ -2472,7 +2469,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
        assert spec is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
-         spec.loader.exec_module(module)  # type: ignore[union-attr]
+         assert spec.loader is not None
+         spec.loader.exec_module(module)
        return module

    @classmethod
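The hunk above swaps a `# type: ignore[union-attr]` for an explicit `assert spec.loader is not None`, which narrows the Optional loader for mypy. A standalone sketch of the same pattern, as a hypothetical module-level helper rather than the classmethod itself:

import importlib.util
import sys
from types import ModuleType


def load_module_from_path(module_name: str, path: str) -> ModuleType:
    spec = importlib.util.spec_from_file_location(module_name, path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    # spec.loader is typed Optional[Loader]; the assert narrows it, so the
    # next call needs no ignore comment.
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module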
@@ -2592,7 +2590,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
    )

    @classmethod
-     def _get_uncompiled_header(cls, device: str) -> str | None:
+     def _get_uncompiled_header(cls, device: str) -> Optional[str]:
        """
        Given a device type, returns the path to a CPP header file to be precompiled.
        Currently, this is only utilized by the cpp_wrapper classes.
@@ -2679,7 +2677,9 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
    )

    @classmethod
-     def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]:
+     def _codegen_buffer(
+         cls, name: str, arg: "HalideInputSpec", cuda: bool
+     ) -> list[str]:
        assert arg.shape is not None
        assert arg.stride is not None and len(arg.shape) == len(arg.stride)
        assert arg.offset is not None
@@ -2713,7 +2713,7 @@ def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[st
        ]

    @classmethod
-     def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str:
+     def _codegen_glue(cls, meta: "HalideMeta", headerfile: object) -> str:
        is_cuda = meta.is_cuda()
        assert is_cuda is ("user_context" in meta.target)
        assert "no_runtime" in meta.target
@@ -2821,7 +2821,7 @@ def find_header(name: str) -> str:

    @classmethod
    def generate_halide_async(
-         cls, meta: HalideMeta, source_code: str, submit_fn: Any = None
+         cls, meta: "HalideMeta", source_code: str, submit_fn: Any = None
    ) -> Callable[[], Any]:
        dirpath = Path(
            get_path(
@@ -2961,6 +2961,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
                job()
    except subprocess.SubprocessError as e:
        if os.environ.get("HALIDE_REPRO") == "1":
+             cmd: list[Any]
            python, script, *cmd = getattr(e, "cmd", ("", "", ""))
            if os.path.basename(python).startswith("python"):
                code = open(script).read()
@@ -2971,7 +2972,9 @@ class Out:
                    def __repr__(self) -> str:
                        return "out"

-                 cmd[cmd.index("-o") + 1] = Out()  # type: ignore[call-overload]
+                 ci = cmd.index("-o")
+                 assert isinstance(ci, int)
+                 cmd[ci + 1] = Out()
                repl = textwrap.indent(
                    textwrap.dedent(
                        f"""\
@@ -3098,7 +3101,7 @@ def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:


def _load_triton_kernel_from_source(
    kernel_name: str, source_code: str
- ) -> CachingAutotuner:
+ ) -> "CachingAutotuner":
    return getattr(PyCodeCache.load(source_code), kernel_name)

@@ -3557,7 +3560,7 @@ def __init__(
        self.result_fn = result_fn
        self.future = future

-     def result(self) -> Callable[..., Any]:  # type: ignore[override]
+     def result(self) -> Callable[..., Any]:
        return self.result_fn()


@@ -3566,7 +3569,7 @@ class StaticAutotunerFuture(CodeCacheFuture):
    A statically launchable CachingAutotuner, loaded from TritonBundler
    """

-     def __init__(self, static_autotuner: CachingAutotuner) -> None:
+     def __init__(self, static_autotuner: "CachingAutotuner") -> None:
        # Pickled version of CachingAutotuner
        self.static_autotuner = static_autotuner
        # This needs to be set in AsyncCompile.triton, in case
@@ -3575,10 +3578,10 @@ def __init__(self, static_autotuner: CachingAutotuner) -> None:
        # since it can be very large.
        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None

-     def result(self) -> CachingAutotuner:
+     def result(self) -> "CachingAutotuner":
        assert self.reload_kernel_from_src is not None
        with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
-             self.static_autotuner.precompile(  # type: ignore[union-attr]
+             self.static_autotuner.precompile(
                warm_cache_only=False,
                reload_kernel=self.reload_kernel_from_src,
                static_triton_bundle_key=None,  # no need to save again