[inductor] Clean typing in codegen/common.py and codecache.py · pytorch/pytorch@b64f4df

Commit b64f4df

[inductor] Clean typing in codegen/common.py and codecache.py

ghstack-source-id: 4b6a04f
Pull Request resolved: #150767
1 parent: 95e0a18

File tree: 3 files changed (+118 −108 lines)
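The whole diff applies one refactor: `from __future__ import annotations` (PEP 563 lazy annotations) is dropped, so annotations are evaluated eagerly again, and every type imported only under `TYPE_CHECKING` is therefore quoted as a forward reference; scattered `# type: ignore` comments are replaced by runtime asserts that narrow types instead. A minimal before/after sketch of the import side (the function name here is an illustrative stand-in, not the real codecache.py API):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only type checkers see this import; the module is never
        # loaded at runtime, avoiding import cost and cycles.
        from .graph import GraphLowering

    # With PEP 563 gone, an unquoted GraphLowering would raise
    # NameError when this def executes; quoting defers evaluation
    # to the type checker.
    def compile_graph(graph: "GraphLowering") -> str:
        return str(graph)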

torch/_inductor/codecache.py

52 additions & 46 deletions
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import base64
 import copyreg
 import dataclasses
@@ -24,6 +22,8 @@
 import threading
 import warnings
 from bisect import bisect_right
+from collections.abc import Generator, KeysView, Sequence
+from concurrent.futures import Future
 from copy import copy
 from ctypes import c_void_p, CDLL, cdll
 from datetime import timedelta
@@ -107,23 +107,22 @@
     )
 else:
 
-    def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_errors(
+        *args: Any, **kwargs: Any
+    ) -> None:
         pass
 
-    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
         pass
 
-    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
         pass
 
-    def use_global_cache() -> bool:  # type: ignore[misc]
+    def use_global_cache() -> bool:
         return False
 
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, KeysView, Sequence
-    from concurrent.futures import Future
-
     from .compile_fx import _CompileFxKwargs, CompiledFxGraph
     from .graph import GraphLowering
     from .ir import ChoiceCaller
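Moving `Generator`, `KeysView`, `Sequence`, and `Future` out of the `TYPE_CHECKING` block follows from the same change: once annotations are evaluated eagerly, names that appear unquoted in signatures must be bound at runtime (and `Sequence` is also used in a real `isinstance` check later in this diff). A minimal sketch of why, with a hypothetical function name:

    # Runtime import: the unquoted annotation below is evaluated when
    # the function is defined, so the name must exist at runtime.
    from collections.abc import Generator

    def count_up(limit: int) -> Generator[int, None, None]:
        yield from range(limit)

Names that appear only in quoted annotations can stay behind `TYPE_CHECKING`, which is the treatment the diff gives `ChoiceCaller`, `GraphLowering`, and the other inductor-internal types.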
@@ -262,11 +261,11 @@ def get_global_cache(self) -> dict[str, Any]:
 
     def lookup(
         self,
-        choices: list[ChoiceCaller],
+        choices: list["ChoiceCaller"],
         op: str,
         inputs: str,
-        benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
-    ) -> dict[ChoiceCaller, float]:
+        benchmark: Optional[Callable[[Any], dict["ChoiceCaller", float]]],
+    ) -> dict["ChoiceCaller", float]:
         """
         Check to see if we have benchmarked the given choice callers. For each
         choice caller:
@@ -612,7 +611,7 @@ def get_hash(self, obj: Any) -> str:
         serialized_data = self.dumps(obj)
         return sha256_hash(serialized_data)
 
-    def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
+    def debug_lines(self, inp: "FxGraphHashDetails") -> list[str]:
         """
         Get a printable string describing in more detail all the attributes
         comprising an object. Useful for debugging when one graph hashes
@@ -729,8 +728,8 @@ class FxGraphHashDetails:
     def __init__(
         self,
         gm: torch.fx.GraphModule,
-        example_inputs: Sequence[InputType],
-        fx_kwargs: _CompileFxKwargs,
+        example_inputs: Sequence["InputType"],
+        fx_kwargs: "_CompileFxKwargs",
         inputs_to_check: Sequence[int],
     ) -> None:
         self.gm = gm
@@ -746,7 +745,8 @@ def __init__(
             if type(v) in (set, OrderedSet):  # noqa: set_linter
                 # Special case to handle set params. Python sets can't be
                 # ordered, so sort the elements and store them in a proxy.
-                self.fx_kwargs[k] = OrderedSetHolder(sorted(v))  # type: ignore[call-overload]
+                assert isinstance(v, Sequence)
+                self.fx_kwargs[k] = OrderedSetHolder(sorted(v))
             else:
                 self.fx_kwargs[k] = v
 
@@ -847,8 +847,8 @@ def _get_custom_pass_detail(
 
 def compiled_fx_graph_hash(
     gm: torch.fx.GraphModule,
-    example_inputs: Sequence[InputType],
-    fx_kwargs: _CompileFxKwargs,
+    example_inputs: Sequence["InputType"],
+    fx_kwargs: "_CompileFxKwargs",
     inputs_to_check: Sequence[int],
 ) -> tuple[str, list[str]]:
     """
@@ -940,7 +940,7 @@ def _get_tmp_dir_for_key(key: str) -> str:
         return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
 
     @staticmethod
-    def _filter_backed_symints(inputs: Sequence[InputType]) -> list[torch.SymInt]:
+    def _filter_backed_symints(inputs: Sequence["InputType"]) -> list[torch.SymInt]:
         """
         Get the backed SymInt objects from the input list. Note that we can never
         have guards that depend on unbacked symint.
@@ -960,11 +960,11 @@ def _get_shape_env() -> Optional[ShapeEnv]:
     @staticmethod
     def _lookup_graph(
         key: str,
-        example_inputs: Sequence[InputType],
+        example_inputs: Sequence["InputType"],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
-        constants: CompiledFxGraphConstants,
-    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
+        constants: "CompiledFxGraphConstants",
+    ) -> tuple[Optional["CompiledFxGraph"], dict[str, Any]]:
         """
         Lookup a compiled graph in the cache by key. On a hit, return the
         deserialized CompiledFxGraph object. On a miss, return None.
@@ -976,7 +976,7 @@ def _lookup_graph(
         hints = [hint_int(s) for s in symints]
 
         def iterate_over_candidates() -> Generator[
-            tuple[CompiledFxGraph, bytes], None, None
+            tuple["CompiledFxGraph", bytes], None, None
         ]:
             if local:
                 subdir = FxGraphCache._get_tmp_dir_for_key(key)
@@ -1113,10 +1113,10 @@ def _write_to_local_cache(key: str, content: bytes) -> None:
     @staticmethod
     def _save_graph(
         key: str,
-        compiled_graph: OutputCode,
-        example_inputs: Sequence[InputType],
+        compiled_graph: "OutputCode",
+        example_inputs: Sequence["InputType"],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
     ) -> None:
         """
         Store a serialized CompiledFxGraph on disk.
@@ -1229,8 +1229,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
     @staticmethod
     def prepare_key(
         gm: torch.fx.GraphModule,
-        example_inputs: Sequence[InputType],
-        fx_kwargs: _CompileFxKwargs,
+        example_inputs: Sequence["InputType"],
+        fx_kwargs: "_CompileFxKwargs",
         inputs_to_check: Sequence[int],
         remote: bool,
     ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
@@ -1264,7 +1264,7 @@ def prepare_key(
         return (key, debug_lines), {}
 
     @staticmethod
-    def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
+    def get_remote_cache() -> Optional["RemoteCache[JsonDataTy]"]:
         """
         Attempts to load the remote cache, returns None on error.
         """
@@ -1280,12 +1280,12 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
     def load_with_key(
         key: str,
         debug_lines: list[str],
-        example_inputs: Sequence[InputType],
+        example_inputs: Sequence["InputType"],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: Optional["RemoteCache[JsonDataTy]"],
         is_backward: bool,
-        constants: CompiledFxGraphConstants,
-    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+        constants: "CompiledFxGraphConstants",
+    ) -> tuple[Optional["CompiledFxGraph"], dict[str, Any]]:
         """
         Lookup the graph with the given key, and return results and metadata.
         Doesn't do any logging on its own, because AOTAutograd handles a cache miss
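Note that subscripted forward references are quoted as a whole, `Optional["RemoteCache[JsonDataTy]"]`, rather than quoting only the base name: the full expression `RemoteCache[JsonDataTy]` would otherwise be evaluated eagerly and raise NameError, since `RemoteCache` exists only for the type checker. A self-contained sketch with stand-in names:

    from typing import TYPE_CHECKING, Generic, Optional, TypeVar

    T = TypeVar("T")

    if TYPE_CHECKING:
        # Stand-in for a class imported only for type checking.
        class RemoteCache(Generic[T]): ...

    # Quoting the full subscripted type defers the whole expression;
    # an unquoted RemoteCache[str] would raise NameError at runtime.
    def get_remote_cache() -> Optional["RemoteCache[str]"]:
        return None

    print(get_remote_cache())  # None; the annotation is never evaluated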
@@ -1392,7 +1392,7 @@ class AotCodeCompiler:
     @classmethod
     def compile(
         cls,
-        graph: GraphLowering,
+        graph: "GraphLowering",
         wrapper_code: str,
         kernel_code: str,
         serialized_extern_kernel_nodes: Optional[str],
@@ -1966,7 +1966,7 @@ def convert_arg(arg: Any) -> Any:
             result = [torch.tensor([]) if r is None else r for r in result]
             for i, r in enumerate(result):
                 assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors"
-            return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result)  # type: ignore[arg-type]
+            return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result)
         else:
             assert isinstance(result, torch.Tensor), op + " returns a non-tensor"
             return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result)
@@ -2308,7 +2308,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
         assert spec is not None
         module = importlib.util.module_from_spec(spec)
         sys.modules[module_name] = module
-        spec.loader.exec_module(module)  # type: ignore[union-attr]
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
         return module
 
     @classmethod
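The other recurring move in this commit: replace a `# type: ignore` with a runtime `assert` that both documents the invariant and lets the type checker narrow an Optional. A self-contained version of the loader case above (the function name is hypothetical):

    import importlib.util
    import sys
    from types import ModuleType

    def load_module_from_path(module_name: str, path: str) -> ModuleType:
        spec = importlib.util.spec_from_file_location(module_name, path)
        assert spec is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        # spec.loader is typed Optional[Loader]; the assert narrows it
        # to Loader, so exec_module needs no `# type: ignore[union-attr]`.
        assert spec.loader is not None
        spec.loader.exec_module(module)
        return module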
@@ -2515,7 +2516,9 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
     )
 
     @classmethod
-    def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]:
+    def _codegen_buffer(
+        cls, name: str, arg: "HalideInputSpec", cuda: bool
+    ) -> list[str]:
         assert arg.shape is not None
         assert arg.stride is not None and len(arg.shape) == len(arg.stride)
         assert arg.offset is not None
@@ -2549,7 +2552,7 @@ def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[st
     ]
 
     @classmethod
-    def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str:
+    def _codegen_glue(cls, meta: "HalideMeta", headerfile: object) -> str:
         is_cuda = meta.is_cuda()
         assert is_cuda is ("user_context" in meta.target)
         assert "no_runtime" in meta.target
@@ -2657,7 +2660,7 @@ def find_header(name: str) -> str:
 
     @classmethod
     def generate_halide_async(
-        cls, meta: HalideMeta, source_code: str, submit_fn: Any = None
+        cls, meta: "HalideMeta", source_code: str, submit_fn: Any = None
     ) -> Callable[[], Any]:
         dirpath = Path(
             get_path(
@@ -2797,6 +2800,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
             job()
         except subprocess.SubprocessError as e:
             if os.environ.get("HALIDE_REPRO") == "1":
+                cmd: list[Any]
                 python, script, *cmd = getattr(e, "cmd", ("", "", ""))
                 if os.path.basename(python).startswith("python"):
                     code = open(script).read()
@@ -2807,7 +2811,9 @@ class Out:
                     def __repr__(self) -> str:
                         return "out"
 
-                cmd[cmd.index("-o") + 1] = Out()  # type: ignore[call-overload]
+                ci = cmd.index("-o")
+                assert isinstance(ci, int)
+                cmd[ci + 1] = Out()
                 repl = textwrap.indent(
                     textwrap.dedent(
                         f"""\
@@ -2934,7 +2940,7 @@ def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
 
 def _load_triton_kernel_from_source(
     kernel_name: str, source_code: str
-) -> CachingAutotuner:
+) -> "CachingAutotuner":
     return getattr(PyCodeCache.load(source_code), kernel_name)
 
 
@@ -3349,7 +3355,7 @@ def __init__(
         self.result_fn = result_fn
         self.future = future
 
-    def result(self) -> Callable[..., Any]:  # type: ignore[override]
+    def result(self) -> Callable[..., Any]:
         return self.result_fn()
 
 
@@ -3358,7 +3364,7 @@ class StaticAutotunerFuture(CodeCacheFuture):
     A statically launchable CachingAutotuner, loaded from TritonBundler
     """
 
-    def __init__(self, static_autotuner: CachingAutotuner) -> None:
+    def __init__(self, static_autotuner: "CachingAutotuner") -> None:
         # Pickled version of CachingAutotuner
         self.static_autotuner = static_autotuner
         # This needs to be set in AsyncCompile.triton, in case
@@ -3367,10 +3373,10 @@ def __init__(self, static_autotuner: CachingAutotuner) -> None:
         # since it can be very large.
        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None
 
-    def result(self) -> CachingAutotuner:
+    def result(self) -> "CachingAutotuner":
         assert self.reload_kernel_from_src is not None
         with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
-            self.static_autotuner.precompile(  # type: ignore[union-attr]
+            self.static_autotuner.precompile(
                 warm_cache_only=False,
                 reload_kernel=self.reload_kernel_from_src,
                 static_triton_bundle_key=None,  # no need to save again
