8000 [inductor] Clean typing in codegen/common.py and codecache.py (#150767) · pytorch/pytorch@8568dbc · GitHub
[go: up one dir, main page]

Skip to content

Commit 8568dbc

Browse files
recpytorchmergebot
authored andcommitted
[inductor] Clean typing in codegen/common.py and codecache.py (#150767)
Pull Request resolved: #150767 Approved by: https://github.com/aorenste
1 parent 27f7b65 commit 8568dbc

File tree

2 files changed

+46
-39
lines changed

2 files changed

+46
-39
lines changed

torch/_inductor/codecache.py

8000 Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,16 @@
112112
)
113113
else:
114114

115-
def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
115+
def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:
116116
pass
117117

118-
def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
118+
def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
119119
pass
120120

121-
def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
121+
def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
122122
pass
123123

124-
def use_global_cache() -> bool: # type: ignore[misc]
124+
def use_global_cache() -> bool:
125125
return False
126126

127127

@@ -2451,7 +2451,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
24512451
assert spec is not None
24522452
module = importlib.util.module_from_spec(spec)
24532453
sys.modules[module_name] = module
2454-
spec.loader.exec_module(module) # type: ignore[union-attr]
2454+
assert spec.loader is not None
2455+
spec.loader.exec_module(module)
24552456
return module
24562457

24572458
@classmethod
@@ -2945,6 +2946,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
29452946
job()
29462947
except subprocess.SubprocessError as e:
29472948
if os.environ.get("HALIDE_REPRO") == "1":
2949+
cmd: list[Any]
29482950
python, script, *cmd = getattr(e, "cmd", ("", "", ""))
29492951
if os.path.basename(python).startswith("python"):
29502952
code = open(script).read()
@@ -2955,7 +2957,9 @@ class Out:
29552957
def __repr__(self) -> str:
29562958
return "out"
29572959

2958-
cmd[cmd.index("-o") + 1] = Out() # type: ignore[call-overload]
2960+
ci = cmd.index("-o")
2961+
assert isinstance(ci, int)
2962+
cmd[ci + 1] = Out()
29592963
repl = textwrap.indent(
29602964
textwrap.dedent(
29612965
f"""\
@@ -3565,7 +3569,7 @@ def __init__(
35653569
self.result_fn = result_fn
35663570
self.future = future
35673571

3568-
def result(self) -> Callable[..., Any]: # type: ignore[override]
3572+
def result(self) -> Callable[..., Any]:
35693573
return self.result_fn()
35703574

35713575

9E88 torch/_inductor/codegen/common.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import os
1313
import re
1414
import tempfile
15-
import typing
1615
from abc import ABC, abstractmethod
1716
from enum import auto, Enum
1817
from itertools import chain
@@ -27,7 +26,7 @@
2726
TYPE_CHECKING,
2827
Union,
2928
)
30-
from typing_extensions import TypeVar
29+
from typing_extensions import Self, TypeVar
3130

3231
import sympy
3332

@@ -408,7 +407,7 @@ def get_backend_features(
408407
if isinstance(device, torch.device):
409408
device_type = device.type
410409
else:
411-
assert isinstance(device, str)
410+
assert isinstance(device, str), type(device)
412411
device_type = device
413412
device = torch.device(device_type)
414413
scheduling_ctor = get_scheduling_for_device(device_type)
@@ -538,7 +537,7 @@ def register_device_op_overrides(
538537

539538

540539
def get_device_op_overrides(device: str) -> DeviceOpOverrides:
541-
assert isinstance(device, str)
540+
assert isinstance(device, str), type(device)
542541

543542
if not device_op_overrides_dict:
544543
from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401
@@ -621,7 +620,7 @@ def check_dtype(
621620
elif config.test_configs.static_cpp_dtype_assert and backend == "cpp":
622621
from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP
623622

624-
assert isinstance(var, CppCSEVariable)
623+
assert isinstance(var, CppCSEVariable), type(var)
625624
if dtype == torch.bool:
626625
if var.is_vec:
627626
is_same_dt = f"IsVecMaskType<decltype({var})>::value"
@@ -682,9 +681,11 @@ def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
682681
return None
683682

684683
if node.target == operator.getitem:
685-
return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type]
684+
node_arg = node.args[0]
685+
assert isinstance(node_arg, torch.fx.Node), type(node_arg)
686+
return self.deduce_node_dtype(node_arg)
686687

687-
assert isinstance(node.target, str)
688+
assert isinstance(node.target, str), type(node.target)
688689

689690
if node.target.startswith("masked_subblock"):
690691
return self.deduce_node_dtype_by_subgraph(node)
@@ -730,8 +731,8 @@ def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
730731
from ..loop_body import LoopBody
731732
from ..scheduler import SchedulerNode
732733

733-
assert isinstance(node, SchedulerNode)
734-
assert isinstance(node._body, LoopBody)
734+
assert isinstance(node, SchedulerNode), type(node)
735+
assert isinstance(node._body, LoopBody), type(node._body)
735736
return DataTypePropagation.propagate_loopbody(node._body)
736737

737738

@@ -1428,7 +1429,7 @@ def output(self, name: str) -> str:
14281429
def make_inplace(self, input_name: str, output_name: str) -> None:
14291430
if input_name in V.graph.unaligned_buffers:
14301431
V.graph.unaligned_buffers.add(output_name)
1431-
assert output_name not in self.inplace_buffers
1432+
assert output_name not in self.inplace_buffers, output_name
14321433
if input_name in self.inplace_buffers:
14331434
buf = self.inplace_buffers[input_name]
14341435
assert not isinstance(buf, RemovedArg)
@@ -1490,7 +1491,7 @@ def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]:
14901491
assert (
14911492
existing_arg.inner_name != arg.inner_name
14921493
and existing_arg.outer_name != arg.outer_name
1493-
)
1494+
), existing_arg
14941495
self.workspace_args.append(arg)
14951496
return arg.inner_name, 0
14961497

@@ -1518,7 +1519,7 @@ def semaphores(self, min_size: sympy.Expr) -> str:
15181519
)
15191520
for existing_arg in self.workspace_args:
15201521
if existing_arg.inner_name == arg.inner_name:
1521-
assert arg == existing_arg
1522+
assert arg == existing_arg, (arg, existing_arg)
15221523
self.workspace_args.append(arg)
15231524
return arg.inner_name
15241525

@@ -1618,7 +1619,7 @@ def python_argdefs(
16181619
) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
16191620
arg_defs: list[ArgName] = []
16201621
call_args: list[str] = []
1621-
arg_types: list[torch.dtype] = []
1622+
arg_types: list[Any] = []
16221623
precompile_args: list[KernelArgType] = []
16231624
for inplaced in unique(self.inplace_buffers.values()):
16241625
if isinstance(inplaced, RemovedArg):
@@ -1651,7 +1652,7 @@ def python_argdefs(
16511652
for outer, inner in self.sizevars.items():
16521653
arg_defs.append(ArgName(inner))
16531654
call_args.append(outer)
1654-
arg_types.append(type(outer)) # type: ignore[arg-type]
1655+
arg_types.append(type(outer))
16551656
precompile_args.append(SizeArg(inner, outer))
16561657
if V.graph.wrapper_code:
16571658
V.graph.wrapper_code.ensure_size_computed(outer)
@@ -1686,7 +1687,7 @@ def is_removed(self, name: str) -> bool:
16861687
# after you do a call into this kernel, which buffers actually contain
16871688
# updated data? Modeled off of python_argdefs.
16881689
def live_output_buffers(self) -> OrderedSet[str]:
1689-
live_outs = OrderedSet() # type: ignore[var-annotated]
1690+
live_outs = OrderedSet[str]()
16901691
for inplaced in unique(self.inplace_buffers.values()):
16911692
if isinstance(inplaced, RemovedArg):
16921693
continue
@@ -1712,7 +1713,7 @@ def __init__(
17121713
dtype: Optional[torch.dtype] = None,
17131714
):
17141715
super().__init__()
1715-
assert isinstance(bounds, ValueRanges)
1716+
assert isinstance(bounds, ValueRanges), type(bounds)
17161717
self.name = name
17171718
self.bounds = bounds
17181719
self.use_count = 1 # track how many times this expression is used
@@ -1782,7 +1783,7 @@ def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
17821783
else:
17831784
self._cache = {}
17841785

1785-
def clone(self) -> typing.Self:
1786+
def clone(self) -> Self:
17861787
return type(self)(
17871788
prefix=self.prefix,
17881789
suffix=self.suffix,
@@ -1793,7 +1794,7 @@ def clone(self) -> typing.Self:
17931794
reduction_cache=self.reduction_cache,
17941795
)
17951796

1796-
def scoped_copy(self) -> typing.Self:
1797+
def scoped_copy(self) -> Self:
17971798
"""Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
17981799
new_cse = self.clone()
17991800
new_cse._cache = ScopedDict(self._cache)
@@ -1918,7 +1919,7 @@ def __init__(self) -> None:
19181919
super().__init__()
19191920
self.exit_stack = contextlib.ExitStack()
19201921

1921-
def __enter__(self) -> typing.Self:
1922+
def __enter__(self) -> Self:
19221923
self.exit_stack.__enter__()
19231924
return self
19241925

@@ -2084,7 +2085,7 @@ def indirect_assert(
20842085
) -> str:
20852086
if isinstance(var, CSEVariable):
20862087
var = str(var)
2087-
assert isinstance(var, str)
2088+
assert isinstance(var, str), type(var)
20882089
assert lower is None or isinstance(lower, str)
20892090
assert upper is None or isinstance(upper, str)
20902091
if lower and upper:
@@ -2113,7 +2114,7 @@ def check_bounds(
21132114
def index_to_str(self, index: sympy.Expr) -> str:
21142115
raise NotImplementedError
21152116

2116-
def __enter__(self) -> typing.Self:
2117+
def __enter__(self) -> Self:
21172118
super().__enter__()
21182119
assert self.overrides
21192120
self.exit_stack.enter_context(
@@ -2184,7 +2185,7 @@ def rename_indexing(
21842185
# adds the necessary kernel args for index expressions
21852186
# and renames variables in index expressions to kernel arg names
21862187
if isinstance(index, (list, tuple)):
2187-
return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
2188+
return [self.rename_indexing(x) for x in index]
21882189
index = V.graph.sizevars.simplify(index)
21892190
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
21902191
replacements = {
@@ -2362,7 +2363,7 @@ def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
23622363
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
23632364
bounds = self._bound_variable(name, *args, **kwargs)
23642365

2365-
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
2366+
value = getattr(self.parent_handler, name)(*args, **kwargs)
23662367
dtype_handler = DtypePropagationOpsHandler()
23672368

23682369
backend = get_current_backend()
@@ -2387,8 +2388,8 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->
23872388
def do_cse(v: str) -> CSEVariable:
23882389
# we tree_map over the output, so we need to fetch corresponding dtype
23892390
nonlocal output_idx
2390-
var_dtype: torch.dtype = (
2391-
output_dtype[output_idx] # type: ignore[assignment]
2391+
var_dtype: Optional[torch.dtype] = (
2392+
output_dtype[output_idx]
23922393
if isinstance(output_dtype, (list, tuple))
23932394
else output_dtype
23942395
)
@@ -2411,6 +2412,7 @@ def do_cse(v: str) -> CSEVariable:
24112412
config.test_configs.runtime_triton_dtype_assert
24122413
or config.test_configs.static_cpp_dtype_assert
24132414
):
2415+
assert var_dtype is not None
24142416
check_dtype(V.kernel.compute, csevar, var_dtype)
24152417
return csevar
24162418

@@ -2433,7 +2435,9 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A
24332435

24342436
fx_node = V.interpreter.current_node
24352437
if fx_node.target == name and self.kernel.node_to_bounds is not None:
2436-
assert isinstance(self.kernel.node_to_bounds, dict)
2438+
assert isinstance(self.kernel.node_to_bounds, dict), type(
2439+
self.kernel.node_to_bounds
2440+
)
24372441
return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
24382442
elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
24392443
# These create lots of inner strings. We would need to compute the bounds at the ops
@@ -2468,14 +2472,14 @@ def indirect_indexing(
24682472
) -> sympy.Symbol:
24692473
if isinstance(size, int):
24702474
size = sympy.Integer(size)
2471-
assert isinstance(size, sympy.Expr), size
2475+
assert isinstance(size, sympy.Expr), (type(size), size)
24722476
# Skip CSE since this doesn't return an expression
24732477

2474-
if var.bounds.lower < 0: # type: ignore[operator]
2478+
if var.bounds.lower < 0:
24752479
if wrap_neg:
24762480
stm = ops.add(var, ops.index_expr(size, torch.long))
24772481
# Mixed negative and non-negative
2478-
if var.bounds.upper >= 0: # type: ignore[operator]
2482+
if var.bounds.upper >= 0:
24792483
lt = ops.lt(var, 0)
24802484
stm = ops.where(lt, stm, var)
24812485
else:
@@ -2492,7 +2496,7 @@ def indirect_indexing(
24922496
neg_bounds.lower + size, neg_bounds.upper + size
24932497
)
24942498
# We don't have a good way of representing the empty range
2495-
if var.bounds.upper >= 0: # type: ignore[operator]
2499+
if var.bounds.upper >= 0:
24962500
pos = var.bounds & ValueRanges(0, int_oo)
24972501
new_bounds = new_bounds | pos
24982502

@@ -2544,8 +2548,7 @@ def store(
25442548
if mode is None:
25452549
self._update_store_cache(name, value)
25462550
if name not in V.graph.removed_buffers:
2547-
return self.kernel.store(name, index, value, mode=mode)
2548-
return None # type: ignore[return-value]
2551+
self.kernel.store(name, index, value, mode=mode)
25492552

25502553
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
25512554
self.kernel.store_buffer_names.add(name)

0 commit comments

Comments
 (0)
0