1
- from __future__ import annotations
2
-
3
1
import base64
4
2
import copyreg
5
3
import dataclasses
24
22
import threading
25
23
import warnings
26
24
from bisect import bisect_right
25
+ from collections .abc import Generator , KeysView , Sequence
26
+ from concurrent .futures import Future
27
27
from copy import copy
28
28
from ctypes import c_void_p , CDLL , cdll
29
29
from datetime import timedelta
107
107
)
108
108
else :
109
109
110
- def log_global_cache_errors (* args : Any , ** kwargs : Any ) -> None : # type: ignore[misc]
110
+ def log_global_cache_errors (
111
+ * args : Any , ** kwargs : Any
112
+ ) -> None :
111
113
pass
112
114
113
- def log_global_cache_stats (* args : Any , ** kwargs : Any ) -> None : # type: ignore[misc]
115
+ def log_global_cache_stats (* args : Any , ** kwargs : Any ) -> None :
114
116
pass
115
117
116
- def log_global_cache_vals (* args : Any , ** kwargs : Any ) -> None : # type: ignore[misc]
118
+ def log_global_cache_vals (* args : Any , ** kwargs : Any ) -> None :
117
119
pass
118
120
119
- def use_global_cache () -> bool : # type: ignore[misc]
121
+ def use_global_cache () -> bool :
120
122
return False
121
123
122
124
123
125
if TYPE_CHECKING :
124
- from collections .abc import Generator , KeysView , Sequence
125
- from concurrent .futures import Future
126
-
127
126
from .compile_fx import _CompileFxKwargs , CompiledFxGraph
128
127
from .graph import GraphLowering
129
128
from .ir import ChoiceCaller
@@ -262,11 +261,11 @@ def get_global_cache(self) -> dict[str, Any]:
262
261
263
262
def lookup (
264
263
self ,
265
- choices : list [ChoiceCaller ],
264
+ choices : list [" ChoiceCaller" ],
266
265
op : str ,
267
266
inputs : str ,
268
- benchmark : Optional [Callable [[Any ], dict [ChoiceCaller , float ]]],
269
- ) -> dict [ChoiceCaller , float ]:
267
+ benchmark : Optional [Callable [[Any ], dict [" ChoiceCaller" , float ]]],
268
+ ) -> dict [" ChoiceCaller" , float ]:
270
269
"""
271
270
Check to see if we have benchmarked the given choice callers. For each
272
271
choice caller:
@@ -612,7 +611,7 @@ def get_hash(self, obj: Any) -> str:
612
611
serialized_data = self .dumps (obj )
613
612
return sha256_hash (serialized_data )
614
613
615
- def debug_lines (self , inp : FxGraphHashDetails ) -> list [str ]:
614
+ def debug_lines (self , inp : " FxGraphHashDetails" ) -> list [str ]:
616
615
"""
617
616
Get a printable string describing in more detail all the attributes
618
617
comprising an object. Useful for debugging when one graph hashes
@@ -729,8 +728,8 @@ class FxGraphHashDetails:
729
728
def __init__ (
730
729
self ,
731
730
gm : torch .fx .GraphModule ,
732
- example_inputs : Sequence [InputType ],
733
- fx_kwargs : _CompileFxKwargs ,
731
+ example_inputs : Sequence [" InputType" ],
732
+ fx_kwargs : " _CompileFxKwargs" ,
734
733
inputs_to_check : Sequence [int ],
735
734
) -> None :
736
735
self .gm = gm
@@ -746,7 +745,8 @@ def __init__(
746
745
if type (v ) in (set , OrderedSet ): # noqa: set_linter
747
746
# Special case to handle set params. Python sets can't be
748
747
# ordered, so sort the elements and store them in a proxy.
749
- self .fx_kwargs [k ] = OrderedSetHolder (sorted (v )) # type: ignore[call-overload]
748
+ assert isinstance (v , Sequence )
749
+ self .fx_kwargs [k ] = OrderedSetHolder (sorted (v ))
750
750
else :
751
751
self .fx_kwargs [k ] = v
752
752
@@ -847,8 +847,8 @@ def _get_custom_pass_detail(
847
847
848
848
def compiled_fx_graph_hash (
849
849
gm : torch .fx .GraphModule ,
850
- example_inputs : Sequence [InputType ],
851
- fx_kwargs : _CompileFxKwargs ,
850
+ example_inputs : Sequence [" InputType" ],
851
+ fx_kwargs : " _CompileFxKwargs" ,
852
852
inputs_to_check : Sequence [int ],
853
853
) -> tuple [str , list [str ]]:
854
854
"""
@@ -940,7 +940,7 @@ def _get_tmp_dir_for_key(key: str) -> str:
940
940
return os .path .join (FxGraphCache ._get_tmp_dir (), key [1 :3 ], key )
941
941
942
942
@staticmethod
943
- def _filter_backed_symints (inputs : Sequence [InputType ]) -> list [torch .SymInt ]:
943
+ def _filter_backed_symints (inputs : Sequence [" InputType" ]) -> list [torch .SymInt ]:
944
944
"""
945
945
Get the backed SymInt objects from the input list. Note that we can never
946
946
have guards that depend on unbacked symint.
@@ -960,11 +960,11 @@ def _get_shape_env() -> Optional[ShapeEnv]:
960
960
@staticmethod
961
961
def _lookup_graph (
962
962
key : str ,
963
- example_inputs : Sequence [InputType ],
963
+ example_inputs : Sequence [" InputType" ],
964
964
local : bool ,
965
- remote_cache : Optional [RemoteCache [JsonDataTy ]],
966
- constants : CompiledFxGraphConstants ,
967
- ) -> tuple [Optional [CompiledFxGraph ], dict [str , Any ]]:
965
+ remote_cache : Optional [" RemoteCache[JsonDataTy]" ],
966
+ constants : " CompiledFxGraphConstants" ,
967
+ ) -> tuple [Optional [" CompiledFxGraph" ], dict [str , Any ]]:
968
968
"""
969
969
Lookup a compiled graph in the cache by key. On a hit, return the
970
970
deserialized CompiledFxGraph object. On a miss, return None.
@@ -976,7 +976,7 @@ def _lookup_graph(
976
976
hints = [hint_int (s ) for s in symints ]
977
977
978
978
def iterate_over_candidates () -> Generator [
979
- tuple [CompiledFxGraph , bytes ], None , None
979
+ tuple [" CompiledFxGraph" , bytes ], None , None
980
980
]:
981
981
if local :
982
982
subdir = FxGraphCache ._get_tmp_dir_for_key (key )
@@ -1113,10 +1113,10 @@ def _write_to_local_cache(key: str, content: bytes) -> None:
1113
1113
@staticmethod
1114
1114
def _save_graph (
1115
1115
key : str ,
1116
- compiled_graph : OutputCode ,
1117
- example_inputs : Sequence [InputType ],
1116
+ compiled_graph : " OutputCode" ,
1117
+ example_inputs : Sequence [" InputType" ],
1118
1118
local : bool ,
1119
- remote_cache : Optional [RemoteCache [JsonDataTy ]],
1119
+ remote_cache : Optional [" RemoteCache[JsonDataTy]" ],
1120
1120
) -> None :
1121
1121
"""
1122
1122
Store a serialized CompiledFxGraph on disk.
@@ -1229,8 +1229,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
1229
1229
@staticmethod
1230
1230
def prepare_key (
1231
1231
gm : torch .fx .GraphModule ,
1232
- example_inputs : Sequence [InputType ],
1233
- fx_kwargs : _CompileFxKwargs ,
1232
+ example_inputs : Sequence [" InputType" ],
1233
+ fx_kwargs : " _CompileFxKwargs" ,
1234
1234
inputs_to_check : Sequence [int ],
1235
1235
remote : bool ,
1236
1236
) -> tuple [Optional [tuple [str , list [str ]]], dict [str , Any ]]:
@@ -1264,7 +1264,7 @@ def prepare_key(
1264
1264
return (key , debug_lines ), {}
1265
1265
1266
1266
@staticmethod
1267
- def get_remote_cache () -> Optional [RemoteCache [JsonDataTy ]]:
1267
+ def get_remote_cache () -> Optional [" RemoteCache[JsonDataTy]" ]:
1268
1268
"""
1269
1269
Attempts to load the remote cache, returns None on error.
1270
1270
"""
@@ -1280,12 +1280,12 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
1280
1280
def load_with_key (
1281
1281
key : str ,
1282
1282
debug_lines : list [str ],
1283
- example_inputs : Sequence [InputType ],
1283
+ example_inputs : Sequence [" InputType" ],
1284
1284
local : bool ,
1285
- remote_cache : Optional [RemoteCache [JsonDataTy ]],
1285
+ remote_cache : Optional [" RemoteCache[JsonDataTy]" ],
1286
1286
is_backward : bool ,
1287
- constants : CompiledFxGraphConstants ,
1288
- ) -> tuple [Optional [CompiledFxGraph ], dict [str , Any ]]:
1287
+ constants : " CompiledFxGraphConstants" ,
1288
+ ) -> tuple [Optional [" CompiledFxGraph" ], dict [str , Any ]]:
1289
1289
"""
1290
1290
Lookup the graph with the given key, and return results and metadata.
1291
1291
Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@@ -1392,7 +1392,7 @@ class AotCodeCompiler:
1392
1392
@classmethod
1393
1393
def compile (
1394
1394
cls ,
1395
- graph : GraphLowering ,
1395
+ graph : " GraphLowering" ,
1396
1396
wrapper_code : str ,
1397
1397
kernel_code : str ,
1398
1398
serialized_extern_kernel_nodes : Optional [str ],
@@ -1966,7 +1966,7 @@ def convert_arg(arg: Any) -> Any:
1966
1966
result = [torch .tensor ([]) if r is None else r for r in result ]
1967
1967
for i , r in enumerate (result ):
1968
1968
assert isinstance (r , torch .Tensor ), op + " returns a list of non-tensors"
1969
- return torch ._C ._aoti .unsafe_alloc_void_ptrs_from_tensors (result ) # type: ignore[arg-type]
1969
+ return torch ._C ._aoti .unsafe_alloc_void_ptrs_from_tensors (result )
1970
1970
else :
1971
1971
assert isinstance (result , torch .Tensor ), op + " returns a non-tensor"
1972
1972
return torch ._C ._aoti .unsafe_alloc_void_ptr_from_tensor (result )
@@ -2308,7 +2308,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
2308
2308
assert spec is not None
2309
2309
module = importlib .util .module_from_spec (spec )
2310
2310
sys .modules [module_name ] = module
2311
- spec .loader .exec_module (module ) # type: ignore[union-attr]
2311
+ assert spec .loader is not None
2312
+ spec .loader .exec_module (module )
2312
2313
return module
2313
2314
2314
2315
@classmethod
@@ -2515,7 +2516,9 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
2515
2516
)
2516
2517
2517
2518
@classmethod
2518
- def _codegen_buffer (cls , name : str , arg : HalideInputSpec , cuda : bool ) -> list [str ]:
2519
+ def _codegen_buffer (
2520
+ cls , name : str , arg : "HalideInputSpec" , cuda : bool
2521
+ ) -> list [str ]:
2519
2522
assert arg .shape is not None
2520
2523
assert arg .stride is not None and len (arg .shape ) == len (arg .stride )
2521
2524
assert arg .offset is not None
@@ -2549,7 +2552,7 @@ def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[st
2549
2552
]
2550
2553
2551
2554
@classmethod
2552
- def _codegen_glue (cls , meta : HalideMeta , headerfile : object ) -> str :
2555
+ def _codegen_glue (cls , meta : " HalideMeta" , headerfile : object ) -> str :
2553
2556
is_cuda = meta .is_cuda ()
2554
2557
assert is_cuda is ("user_context" in meta .target )
2555
2558
assert "no_runtime" in meta .target
@@ -2657,7 +2660,7 @@ def find_header(name: str) -> str:
2657
2660
2658
2661
@classmethod
2659
2662
def generate_halide_async (
2660
- cls , meta : HalideMeta , source_code : str , submit_fn : Any = None
2663
+ cls , meta : " HalideMeta" , source_code : str , submit_fn : Any = None
2661
2664
) -> Callable [[], Any ]:
2662
2665
dirpath = Path (
2663
2666
get_path (
@@ -2797,6 +2800,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
2797
2800
job ()
2798
2801
except subprocess .SubprocessError as e :
2799
2802
if os .environ .get ("HALIDE_REPRO" ) == "1" :
2803
+ cmd : list [Any ]
2800
2804
python , script , * cmd = getattr (e , "cmd" , ("" , "" , "" ))
2801
2805
if os .path .basename (python ).startswith ("python" ):
2802
2806
code = open (script ).read ()
@@ -2807,7 +2811,9 @@ class Out:
2807
2811
def __repr__ (self ) -> str :
2808
2812
return "out"
2809
2813
2810
- cmd [cmd .index ("-o" ) + 1 ] = Out () # type: ignore[call-overload]
2814
+ ci = cmd .index ("-o" )
2815
+ assert isinstance (ci , int )
2816
+ cmd [ci + 1 ] = Out ()
2811
2817
repl = textwrap .indent (
2812
2818
textwrap .dedent (
2813
2819
f"""\
@@ -2934,7 +2940,7 @@ def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
2934
2940
2935
2941
def _load_triton_kernel_from_source (
2936
2942
kernel_name : str , source_code : str
2937
- ) -> CachingAutotuner :
2943
+ ) -> " CachingAutotuner" :
2938
2944
return getattr (PyCodeCache .load (source_code ), kernel_name )
2939
2945
2940
2946
@@ -3349,7 +3355,7 @@ def __init__(
3349
3355
self .result_fn = result_fn
3350
3356
self .future = future
3351
3357
3352
- def result (self ) -> Callable [..., Any ]: # type: ignore[override]
3358
+ def result (self ) -> Callable [..., Any ]:
3353
3359
return self .result_fn ()
3354
3360
3355
3361
@@ -3358,7 +3364,7 @@ class StaticAutotunerFuture(CodeCacheFuture):
3358
3364
A statically launchable CachingAutotuner, loaded from TritonBundler
3359
3365
"""
3360
3366
3361
- def __init__ (self , static_autotuner : CachingAutotuner ) -> None :
3367
+ def __init__ (self , static_autotuner : " CachingAutotuner" ) -> None :
3362
3368
# Pickled version of CachingAutotuner
3363
3369
self .static_autotuner = static_autotuner
3364
3370
# This needs to be set in AsyncCompile.triton, in case
@@ -3367,10 +3373,10 @@ def __init__(self, static_autotuner: CachingAutotuner) -> None:
3367
3373
# since it can be very large.
3368
3374
self .reload_kernel_from_src : Optional [Callable [[], Any ]] = None
3369
3375
3370
- def result (self ) -> CachingAutotuner :
3376
+ def result (self ) -> " CachingAutotuner" :
3371
3377
assert self .reload_kernel_from_src is not None
3372
3378
with dynamo_timed ("StaticAutotunerFuture.warm_precompile" ):
3373
- self .static_autotuner .precompile ( # type: ignore[union-attr]
3379
+ self .static_autotuner .precompile (
3374
3380
warm_cache_only = False ,
3375
3381
reload_kernel = self .reload_kernel_from_src ,
3376
3382
static_triton_bundle_key = None , # no need to save again
0 commit comments