[inductor] Clean typing in codegen/common.py and codecache.py · pytorch/pytorch@9b82380 · GitHub

Commit 9b82380

[inductor] Clean typing in codegen/common.py and codecache.py
ghstack-source-id: 0880069
Pull Request resolved: #150767
1 parent 032ef48 commit 9b82380

File tree

torch/_inductor/codecache.py
torch/_inductor/codegen/common.py

2 files changed: +135 -131 lines
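
The cleanup follows one pattern throughout the diff below: "from __future__ import annotations" is removed, so annotations are evaluated again when each function is defined. Runtime-safe imports (collections.abc, concurrent.futures) therefore move out of the TYPE_CHECKING block, names that stay type-checking-only become quoted forward references, and "X | None" unions become Optional[X], presumably so the signatures still evaluate on interpreters older than Python 3.10. A minimal sketch of the pattern (illustrative only, not the file's actual contents; compile_graph and output_path are hypothetical names, GraphLowering is taken from the diff):

    from typing import Optional, TYPE_CHECKING

    if TYPE_CHECKING:
        # Heavy or circular import: only needed by the type checker,
        # never executed at runtime.
        from .graph import GraphLowering

    # Without "from __future__ import annotations", these annotations are
    # evaluated at definition time, so GraphLowering must be quoted and the
    # union must be spelled Optional[str] rather than "str | None" (which
    # would raise at definition time on Python < 3.10).
    def compile_graph(graph: "GraphLowering", output_path: Optional[str] = None) -> str:
        ...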

torch/_inductor/codecache.py

Lines changed: 56 additions & 53 deletions
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import base64
 import copyreg
 import dataclasses
@@ -25,6 +23,8 @@
 import threading
 import warnings
 from bisect import bisect_right
+from collections.abc import Generator, KeysView, Sequence
+from concurrent.futures import Future
 from copy import copy
 from ctypes import c_void_p, CDLL, cdll
 from datetime import timedelta
@@ -112,25 +112,22 @@
 )
 else:
 
-    def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:
         pass
 
-    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
         pass
 
-    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
        pass
 
-    def use_global_cache() -> bool:  # type: ignore[misc]
+    def use_global_cache() -> bool:
         return False
 
 
 T = TypeVar("T")
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, KeysView, Sequence
-    from concurrent.futures import Future
-
     from .compile_fx import _CompileFxKwargs
     from .graph import GraphLowering
     from .ir import ChoiceCaller
@@ -267,11 +264,11 @@ def get_global_cache(self) -> dict[str, Any]:
 
     def lookup(
         self,
-        choices: list[ChoiceCaller],
+        choices: list["ChoiceCaller"],
         op: str,
         inputs: str,
-        benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
-    ) -> dict[ChoiceCaller, float]:
+        benchmark: Optional[Callable[[Any], dict["ChoiceCaller", float]]],
+    ) -> dict["ChoiceCaller", float]:
         """
         Check to see if we have benchmarked the given choice callers. For each
         choice caller:
@@ -617,7 +614,7 @@ def get_hash(self, obj: Any) -> str:
         serialized_data = self.dumps(obj)
         return sha256_hash(serialized_data)
 
-    def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
+    def debug_lines(self, inp: "FxGraphHashDetails") -> list[str]:
         """
         Get a printable string describing in more detail all the attributes
         comprising an object. Useful for debugging when one graph hashes
@@ -652,7 +649,7 @@ def get_str(obj: Any) -> str:
 
 
 def build_code_hash(
-    roots: list[str] | None, prefix: str, hasher: hashlib._Hash
+    roots: Optional[list[str]], prefix: str, hasher: "hashlib._Hash"
 ) -> None:
     for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
         spec = lib.module_finder.find_spec(lib.name, None)
@@ -759,8 +756,8 @@ class FxGraphHashDetails:
     def __init__(
         self,
         gm: torch.fx.GraphModule,
-        example_inputs: Sequence[InputType],
-        fx_kwargs: _CompileFxKwargs,
+        example_inputs: Sequence["InputType"],
+        fx_kwargs: "_CompileFxKwargs",
         inputs_to_check: Sequence[int],
     ) -> None:
         self.gm = gm
@@ -877,8 +874,8 @@ def _get_custom_pass_detail(
 
 def compiled_fx_graph_hash(
     gm: torch.fx.GraphModule,
-    example_inputs: Sequence[InputType],
-    fx_kwargs: _CompileFxKwargs,
+    example_inputs: Sequence["InputType"],
+    fx_kwargs: "_CompileFxKwargs",
     inputs_to_check: Sequence[int],
 ) -> tuple[str, list[str]]:
     """
@@ -931,14 +928,14 @@ class GuardedCache(Generic[T]):
     """
 
     @classmethod
-    def _get_tmp_dir_for_key(cls: type[GuardedCache[T]], _key: str) -> str:
+    def _get_tmp_dir_for_key(cls: type["GuardedCache[T]"], _key: str) -> str:
         raise NotImplementedError("Implement _get_tmp_dir_for_key on parent class")
 
     @classmethod
     def iterate_over_candidates(
-        cls: type[GuardedCache[T]],
+        cls: type["GuardedCache[T]"],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
         key: str,
     ) -> Generator[tuple[T, bytes], None, None]:
         if local:
@@ -970,10 +967,10 @@ def iterate_over_candidates(
 
     @classmethod
     def find_guarded_entry(
-        cls: type[GuardedCache[T]],
+        cls: type["GuardedCache[T]"],
         key: str,
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
         evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
         hints: list[int],
     ) -> tuple[Optional[T], Optional[bytes], dict[str, str]]:
@@ -1031,7 +1028,7 @@ def find_guarded_entry(
 
     @classmethod
     def _filter_backed_symints(
-        cls: type[GuardedCache[T]], inputs: Sequence[InputType]
+        cls: type["GuardedCache[T]"], inputs: Sequence["InputType"]
     ) -> list[torch.SymInt]:
         """
         Get the backed SymInt objects from the input list. Note that we can never
@@ -1040,7 +1037,7 @@ def _filter_backed_symints(
         return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)]
 
     @classmethod
-    def _get_shape_env(cls: type[GuardedCache[T]]) -> Optional[ShapeEnv]:
+    def _get_shape_env(cls: type["GuardedCache[T]"]) -> Optional[ShapeEnv]:
         """
         Helper to get the shape env from the tracing context.
         """
@@ -1088,7 +1085,7 @@ def _get_tmp_dir() -> str:
         return os.path.join(cache_dir(), "fxgraph")
 
     @classmethod
-    def _get_tmp_dir_for_key(cls: type[FxGraphCache], key: str) -> str:
+    def _get_tmp_dir_for_key(cls: type["FxGraphCache"], key: str) -> str:
         """
         Return the disk location for a given cache key.
         """
@@ -1098,7 +1095,7 @@ def _get_tmp_dir_for_key(cls: type[FxGraphCache], key: str) -> str:
     def cache_hit_post_compile(
         graph: CompiledFxGraph,
         cache_info: dict[str, Any],
-        constants: CompiledFxGraphConstants,
+        constants: "CompiledFxGraphConstants",
     ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
         """
         Cache specific post compile steps that need to run if we find a graph in the cache
@@ -1173,10 +1170,10 @@ def cache_hit_post_compile(
     @staticmethod
     def _lookup_graph(
         key: str,
-        example_inputs: Sequence[InputType],
+        example_inputs: Sequence["InputType"],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
-        constants: CompiledFxGraphConstants,
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
+        constants: "CompiledFxGraphConstants",
         evaluate_guards: Optional[
             Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
         ] = None,
@@ -1245,10 +1242,10 @@ def _write_to_local_cache(key: str, content: bytes) -> None:
     @staticmethod
     def _save_graph(
         key: str,
-        compiled_graph: OutputCode,
-        example_inputs: Sequence[InputType],
+        compiled_graph: "OutputCode",
+        example_inputs: Sequence["InputType"],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
     ) -> None:
         """
         Store a serialized CompiledFxGraph on disk.
@@ -1361,8 +1358,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
     @staticmethod
    def prepare_key(
         gm: torch.fx.GraphModule,
-        example_inputs: Sequence[InputType],
-        fx_kwargs: _CompileFxKwargs,
+        example_inputs: Sequence["InputType"],
+        fx_kwargs: "_CompileFxKwargs",
         inputs_to_check: Sequence[int],
         remote: bool,
     ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
@@ -1396,7 +1393,7 @@ def prepare_key(
         return (key, debug_lines), {}
 
     @staticmethod
-    def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
+    def get_remote_cache() -> Optional["RemoteCache[JsonDataTy]"]:
         """
         Attempts to load the remote cache, returns None on error.
         """
@@ -1412,15 +1409,15 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
     def load_with_key(
         key: str,
         debug_lines: list[str],
-        example_inputs: Sequence[InputType],
+        example_inputs: Sequence["InputType"],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
         is_backward: bool,
-        constants: CompiledFxGraphConstants,
+        constants: "CompiledFxGraphConstants",
         evaluate_guards: Optional[
             Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
         ] = None,
-    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+    ) -> tuple[Optional["CompiledFxGraph"], dict[str, Any]]:
         """
         Lookup the graph with the given key, and return results and metadata.
         Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@@ -1535,7 +1532,7 @@ class AotCodeCompiler:
     @classmethod
     def compile(
         cls,
-        graph: GraphLowering,
+        graph: "GraphLowering",
         wrapper_code: str,
         kernel_code: str,
         serialized_extern_kernel_nodes: Optional[str],
@@ -2252,7 +2249,7 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
             raise
 
     @classmethod
-    def _get_uncompiled_header(cls, device: str) -> str | None:
+    def _get_uncompiled_header(cls, device: str) -> Optional[str]:
         """
         Given a device type, returns the path to a CPP header file to be precompiled.
         Currently, this is only utilized by the cpp_wrapper classes.
@@ -2472,7 +2469,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
         assert spec is not None
         module = importlib.util.module_from_spec(spec)
         sys.modules[module_name] = module
-        spec.loader.exec_module(module)  # type: ignore[union-attr]
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
         return module
 
     @classmethod
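
The hunk above shows the commit's replacement for a type-ignore: instead of suppressing mypy's union-attr error on spec.loader (which typeshed types as an Optional loader), it asserts that the loader is present so the checker can narrow the type. A self-contained sketch of the same narrowing pattern, using a hypothetical load_module helper:

    import importlib.util
    import sys
    from types import ModuleType

    def load_module(module_name: str, path: str) -> ModuleType:
        spec = importlib.util.spec_from_file_location(module_name, path)
        assert spec is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        # spec.loader is Optional; the assert narrows it, so exec_module
        # type-checks without a "type: ignore[union-attr]" comment and
        # fails loudly at runtime if the spec has no loader.
        assert spec.loader is not None
        spec.loader.exec_module(module)
        return module
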
@@ -2592,7 +2590,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
     )
 
     @classmethod
-    def _get_uncompiled_header(cls, device: str) -> str | None:
+    def _get_uncompiled_header(cls, device: str) -> Optional[str]:
         """
         Given a device type, returns the path to a CPP header file to be precompiled.
         Currently, this is only utilized by the cpp_wrapper classes.
@@ -2679,7 +2677,9 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
     )
 
     @classmethod
-    def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]:
+    def _codegen_buffer(
+        cls, name: str, arg: "HalideInputSpec", cuda: bool
+    ) -> list[str]:
         assert arg.shape is not None
         assert arg.stride is not None and len(arg.shape) == len(arg.stride)
         assert arg.offset is not None
@@ -2713,7 +2713,7 @@ def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[st
         ]
 
     @classmethod
-    def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str:
+    def _codegen_glue(cls, meta: "HalideMeta", headerfile: object) -> str:
         is_cuda = meta.is_cuda()
         assert is_cuda is ("user_context" in meta.target)
         assert "no_runtime" in meta.target
@@ -2821,7 +2821,7 @@ def find_header(name: str) -> str:
 
     @classmethod
     def generate_halide_async(
-        cls, meta: HalideMeta, source_code: str, submit_fn: Any = None
+        cls, meta: "HalideMeta", source_code: str, submit_fn: Any = None
     ) -> Callable[[], Any]:
         dirpath = Path(
             get_path(
@@ -2961,6 +2961,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
                 job()
     except subprocess.SubprocessError as e:
         if os.environ.get("HALIDE_REPRO") == "1":
+            cmd: list[Any]
             python, script, *cmd = getattr(e, "cmd", ("", "", ""))
             if os.path.basename(python).startswith("python"):
                 code = open(script).read()
@@ -2971,7 +2972,9 @@ class Out:
                     def __repr__(self) -> str:
                         return "out"
 
-                cmd[cmd.index("-o") + 1] = Out()  # type: ignore[call-overload]
+                ci = cmd.index("-o")
+                assert isinstance(ci, int)
+                cmd[ci + 1] = Out()
                 repl = textwrap.indent(
                     textwrap.dedent(
                         f"""\
@@ -3098,7 +3101,7 @@ def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
 
 def _load_triton_kernel_from_source(
     kernel_name: str, source_code: str
-) -> CachingAutotuner:
+) -> "CachingAutotuner":
     return getattr(PyCodeCache.load(source_code), kernel_name)
 
 
@@ -3557,7 +3560,7 @@ def __init__(
         self.result_fn = result_fn
         self.future = future
 
-    def result(self) -> Callable[..., Any]:  # type: ignore[override]
+    def result(self) -> Callable[..., Any]:
         return self.result_fn()
 
 
@@ -3566,7 +3569,7 @@ class StaticAutotunerFuture(CodeCacheFuture):
     A statically launchable CachingAutotuner, loaded from TritonBundler
     """
 
-    def __init__(self, static_autotuner: CachingAutotuner) -> None:
+    def __init__(self, static_autotuner: "CachingAutotuner") -> None:
         # Pickled version of CachingAutotuner
         self.static_autotuner = static_autotuner
         # This needs to be set in AsyncCompile.triton, in case
@@ -3575,10 +3578,10 @@ def __init__(self, static_autotuner: CachingAutotuner) -> None:
         # since it can be very large.
         self.reload_kernel_from_src: Optional[Callable[[], Any]] = None
 
-    def result(self) -> CachingAutotuner:
+    def result(self) -> "CachingAutotuner":
         assert self.reload_kernel_from_src is not None
         with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
-            self.static_autotuner.precompile(  # type: ignore[union-attr]
+            self.static_autotuner.precompile(
                 warm_cache_only=False,
                 reload_kernel=self.reload_kernel_from_src,
                 static_triton_bundle_key=None,  # no need to save again
