[inductor] Clean typing in codegen/common.py and codecache.py by rec · Pull Request #150767 · pytorch/pytorch · GitHub

Closed · wants to merge 17 commits
18 changes: 11 additions & 7 deletions torch/_inductor/codecache.py
@@ -112,16 +112,16 @@
 )
 else:

-    def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:
         pass

-    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
         pass

-    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+    def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
         pass

-    def use_global_cache() -> bool:  # type: ignore[misc]
+    def use_global_cache() -> bool:
         return False
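The four stubs above are the fallback branch of a conditional import (the real branch is gated on an internal fbcode config check). mypy reports a conditional redefinition under the [misc] code only when the signatures disagree, so the ignores coming off cleanly suggests the stub signatures now match the imported originals. A minimal standalone sketch of the pattern, with a hypothetical module name:

    from typing import Any

    try:
        from remote_logging import log_global_cache_stats  # hypothetical module
    except ImportError:

        def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
            pass  # no-op fallback when the logging backend is unavailable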


@@ -2472,7 +2472,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
         assert spec is not None
         module = importlib.util.module_from_spec(spec)
         sys.modules[module_name] = module
-        spec.loader.exec_module(module)  # type: ignore[union-attr]
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
         return module

     @classmethod
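The union-attr ignore above was hiding a real Optional: in typeshed, both importlib.util.spec_from_file_location and ModuleSpec.loader can be None, so the fix narrows each with an assert instead of silencing the checker. A self-contained sketch of the same pattern:

    import importlib.util
    import sys
    from types import ModuleType

    def load_module_from_path(module_name: str, path: str) -> ModuleType:
        spec = importlib.util.spec_from_file_location(module_name, path)
        assert spec is not None  # narrows Optional[ModuleSpec] -> ModuleSpec
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        assert spec.loader is not None  # narrows Optional[Loader] -> Loader
        spec.loader.exec_module(module)
        return module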
@@ -2961,6 +2962,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
                 job()
     except subprocess.SubprocessError as e:
         if os.environ.get("HALIDE_REPRO") == "1":
+            cmd: list[Any]
             python, script, *cmd = getattr(e, "cmd", ("", "", ""))
             if os.path.basename(python).startswith("python"):
                 code = open(script).read()
@@ -2971,7 +2973,9 @@ class Out:
                     def __repr__(self) -> str:
                         return "out"

-                cmd[cmd.index("-o") + 1] = Out()  # type: ignore[call-overload]
+                ci = cmd.index("-o")
+                assert isinstance(ci, int)
+                cmd[ci + 1] = Out()
                 repl = textwrap.indent(
                     textwrap.dedent(
                         f"""\
@@ -3581,7 +3585,7 @@ def __init__(
         self.result_fn = result_fn
         self.future = future

-    def result(self) -> Callable[..., Any]:  # type: ignore[override]
+    def result(self) -> Callable[..., Any]:
         return self.result_fn()


67 changes: 35 additions & 32 deletions torch/_inductor/codegen/common.py
@@ -12,7 +12,6 @@
 import os
 import re
 import tempfile
-import typing
 from abc import ABC, abstractmethod
 from enum import auto, Enum
 from itertools import chain
@@ -27,7 +26,7 @@
     TYPE_CHECKING,
     Union,
 )
-from typing_extensions import TypeVar
+from typing_extensions import Self, TypeVar

 import sympy
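typing.Self only exists on Python 3.11+, which is presumably why the PR switches to the typing_extensions backport and drops the bare import typing. Self also carries more information than a hand-written return type: a subclass calling an inherited method gets its own type back. A minimal sketch:

    from typing_extensions import Self

    class Buffer:
        def clone(self) -> Self:
            # The checker resolves Self to the subclass on inherited calls.
            return type(self)()

    class VecBuffer(Buffer):  # hypothetical subclass for illustration
        pass

    vb: VecBuffer = VecBuffer().clone()  # OK: clone() returns VecBuffer here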

@@ -408,7 +407,7 @@ def get_backend_features(
     if isinstance(device, torch.device):
         device_type = device.type
     else:
-        assert isinstance(device, str)
+        assert isinstance(device, str), type(device)
         device_type = device
         device = torch.device(device_type)
     scheduling_ctor = get_scheduling_for_device(device_type)
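A pattern that recurs throughout this diff: the expression after the comma in an assert becomes the AssertionError payload, so a failure now reports the offending runtime type instead of a bare traceback line. For example:

    device = 3.14  # deliberately the wrong type
    try:
        assert isinstance(device, str), type(device)
    except AssertionError as e:
        print(e)  # <class 'float'>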
@@ -538,7 +537,7 @@ def register_device_op_overrides(


 def get_device_op_overrides(device: str) -> DeviceOpOverrides:
-    assert isinstance(device, str)
+    assert isinstance(device, str), type(device)

     if not device_op_overrides_dict:
         from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
@@ -621,7 +620,7 @@ def check_dtype(
     elif config.test_configs.static_cpp_dtype_assert and backend == "cpp":
         from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP

-        assert isinstance(var, CppCSEVariable)
+        assert isinstance(var, CppCSEVariable), type(var)
         if dtype == torch.bool:
             if var.is_vec:
                 is_same_dt = f"IsVecMaskType<decltype({var})>::value"
@@ -682,9 +681,11 @@ def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
             return None

         if node.target == operator.getitem:
-            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]
+            node_arg = node.args[0]
+            assert isinstance(node_arg, torch.fx.Node), type(node_arg)
+            return self.deduce_node_dtype(node_arg)

-        assert isinstance(node.target, str)
+        assert isinstance(node.target, str), type(node.target)

         if node.target.startswith("masked_subblock"):
             return self.deduce_node_dtype_by_subgraph(node)
@@ -730,8 +731,8 @@ def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
         from ..loop_body import LoopBody
         from ..scheduler import SchedulerNode

-        assert isinstance(node, SchedulerNode)
-        assert isinstance(node._body, LoopBody)
+        assert isinstance(node, SchedulerNode), type(node)
+        assert isinstance(node._body, LoopBody), type(node._body)
         return DataTypePropagation.propagate_loopbody(node._body)


@@ -1428,7 +1429,7 @@ def output(self, name: str) -> str:
     def make_inplace(self, input_name: str, output_name: str) -> None:
         if input_name in V.graph.unaligned_buffers:
             V.graph.unaligned_buffers.add(output_name)
-        assert output_name not in self.inplace_buffers
+        assert output_name not in self.inplace_buffers, output_name
         if input_name in self.inplace_buffers:
             buf = self.inplace_buffers[input_name]
             assert not isinstance(buf, RemovedArg)
@@ -1490,7 +1491,7 @@ def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]:
             assert (
                 existing_arg.inner_name != arg.inner_name
                 and existing_arg.outer_name != arg.outer_name
-            )
+            ), existing_arg
         self.workspace_args.append(arg)
         return arg.inner_name, 0

@@ -1518,7 +1519,7 @@ def semaphores(self, min_size: sympy.Expr) -> str:
         )
         for existing_arg in self.workspace_args:
             if existing_arg.inner_name == arg.inner_name:
-                assert arg == existing_arg
+                assert arg == existing_arg, (arg, existing_arg)
         self.workspace_args.append(arg)
         return arg.inner_name

@@ -1618,7 +1619,7 @@ def python_argdefs(
     ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
         arg_defs: list[ArgName] = []
         call_args: list[str] = []
-        arg_types: list[torch.dtype] = []
+        arg_types: list[Any] = []
         precompile_args: list[KernelArgType] = []
         for inplaced in unique(self.inplace_buffers.values()):
             if isinstance(inplaced, RemovedArg):
@@ -1651,7 +1652,7 @@ def python_argdefs(
         for outer, inner in self.sizevars.items():
             arg_defs.append(ArgName(inner))
             call_args.append(outer)
-            arg_types.append(type(outer))  # type: ignore[arg-type]
+            arg_types.append(type(outer))
             precompile_args.append(SizeArg(inner, outer))
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
@@ -1686,7 +1687,7 @@ def is_removed(self, name: str) -> bool:
     # after you do a call into this kernel, which buffers actually contain
     # updated data? Modeled off of python_argdefs.
     def live_output_buffers(self) -> OrderedSet[str]:
-        live_outs = OrderedSet()  # type: ignore[var-annotated]
+        live_outs = OrderedSet[str]()
         for inplaced in unique(self.inplace_buffers.values()):
             if isinstance(inplaced, RemovedArg):
                 continue
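Since Python 3.9, a generic class can be parameterized directly at the call site, which hands the checker the element type without a separate variable annotation. OrderedSet is an internal PyTorch utility; a builtin shows the same two options:

    # Option 1: annotate the variable and construct unparameterized.
    live_outs: set[str] = set()

    # Option 2: parameterize the constructor itself, as the diff does.
    live_outs = set[str]()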
@@ -1712,7 +1713,7 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
-        assert isinstance(bounds, ValueRanges)
+        assert isinstance(bounds, ValueRanges), type(bounds)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
@@ -1782,7 +1783,7 @@ def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
         else:
             self._cache = {}

-    def clone(self) -> typing.Self:
+    def clone(self) -> Self:
         return type(self)(
             prefix=self.prefix,
             suffix=self.suffix,
@@ -1793,7 +1794,7 @@ def clone(self) -> typing.Self:
             reduction_cache=self.reduction_cache,
         )

-    def scoped_copy(self) -> typing.Self:
+    def scoped_copy(self) -> Self:
         """Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
         new_cse = self.clone()
         new_cse._cache = ScopedDict(self._cache)
@@ -1918,7 +1919,7 @@ def __init__(self) -> None:
         super().__init__()
         self.exit_stack = contextlib.ExitStack()

-    def __enter__(self) -> typing.Self:
+    def __enter__(self) -> Self:
         self.exit_stack.__enter__()
         return self

@@ -2084,7 +2085,7 @@ def indirect_assert(
     ) -> str:
         if isinstance(var, CSEVariable):
             var = str(var)
-        assert isinstance(var, str)
+        assert isinstance(var, str), type(var)
         assert lower is None or isinstance(lower, str)
         assert upper is None or isinstance(upper, str)
         if lower and upper:
@@ -2113,7 +2114,7 @@ def check_bounds(
     def index_to_str(self, index: sympy.Expr) -> str:
         raise NotImplementedError

-    def __enter__(self) -> typing.Self:
+    def __enter__(self) -> Self:
         super().__enter__()
         assert self.overrides
         self.exit_stack.enter_context(
@@ -2184,7 +2185,7 @@ def rename_indexing(
         # adds the necessary kernel args for index expressions
         # and renames variables in index expressions to kernel arg names
         if isinstance(index, (list, tuple)):
-            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
+            return [self.rename_indexing(x) for x in index]
         index = V.graph.sizevars.simplify(index)
         sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
         replacements = {
@@ -2362,7 +2363,7 @@ def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
         bounds = self._bound_variable(name, *args, **kwargs)

-        value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
+        value = getattr(self.parent_handler, name)(*args, **kwargs)
         dtype_handler = DtypePropagationOpsHandler()

         backend = get_current_backend()
@@ -2387,8 +2388,8 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->
         def do_cse(v: str) -> CSEVariable:
             # we tree_map over the output, so we need to fetch corresponding dtype
             nonlocal output_idx
-            var_dtype: torch.dtype = (
-                output_dtype[output_idx]  # type: ignore[assignment]
+            var_dtype: Optional[torch.dtype] = (
+                output_dtype[output_idx]
                 if isinstance(output_dtype, (list, tuple))
                 else output_dtype
             )
@@ -2411,6 +2412,7 @@ def do_cse(v: str) -> CSEVariable:
                 config.test_configs.runtime_triton_dtype_assert
                 or config.test_configs.static_cpp_dtype_assert
             ):
+                assert var_dtype is not None
                 check_dtype(V.kernel.compute, csevar, var_dtype)
             return csevar
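Widening the annotation to Optional[torch.dtype] removes the assignment ignore and moves the obligation to an explicit narrowing assert at the one point a concrete dtype is required. A minimal sketch of the widen-then-narrow pattern, with str standing in for torch.dtype:

    from typing import Optional, Union

    def pick_dtype(
        output_dtype: Union[Optional[str], list[Optional[str]]], idx: int
    ) -> str:
        # Widened annotation: the lookup may legitimately produce None.
        var_dtype: Optional[str] = (
            output_dtype[idx] if isinstance(output_dtype, list) else output_dtype
        )
        assert var_dtype is not None  # narrow back where a value is required
        return var_dtype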

@@ -2433,7 +2435,9 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A

         fx_node = V.interpreter.current_node
         if fx_node.target == name and self.kernel.node_to_bounds is not None:
-            assert isinstance(self.kernel.node_to_bounds, dict)
+            assert isinstance(self.kernel.node_to_bounds, dict), type(
+                self.kernel.node_to_bounds
+            )
             return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
         elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
             # These create lots of inner strings. We would need to compute the bounds at the ops
@@ -2468,14 +2472,14 @@ def indirect_indexing(
     ) -> sympy.Symbol:
         if isinstance(size, int):
             size = sympy.Integer(size)
-        assert isinstance(size, sympy.Expr), size
+        assert isinstance(size, sympy.Expr), (type(size), size)
         # Skip CSE since this doesn't return an expression

-        if var.bounds.lower < 0:  # type: ignore[operator]
+        if var.bounds.lower < 0:
             if wrap_neg:
                 stm = ops.add(var, ops.index_expr(size, torch.long))
                 # Mixed negative and non-negative
-                if var.bounds.upper >= 0:  # type: ignore[operator]
+                if var.bounds.upper >= 0:
                     lt = ops.lt(var, 0)
                     stm = ops.where(lt, stm, var)
             else:
@@ -2492,7 +2496,7 @@
                     neg_bounds.lower + size, neg_bounds.upper + size
                 )
                 # We don't have a good way of representing the empty range
-                if var.bounds.upper >= 0:  # type: ignore[operator]
+                if var.bounds.upper >= 0:
                     pos = var.bounds & ValueRanges(0, int_oo)
                     new_bounds = new_bounds | pos

@@ -2544,8 +2548,7 @@ def store(
         if mode is None:
             self._update_store_cache(name, value)
         if name not in V.graph.removed_buffers:
-            return self.kernel.store(name, index, value, mode=mode)
-        return None  # type: ignore[return-value]
+            self.kernel.store(name, index, value, mode=mode)

     def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         self.kernel.store_buffer_names.add(name)