From 0f6ee2fde9d359eeaf69b7d03d3da6341524c898 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Mon, 7 Apr 2025 12:07:43 +0000 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 72 +++++++++++++++++-------------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 7fce40e869efb..1a95546a805c5 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import contextlib import dataclasses import enum @@ -10,6 +8,7 @@ import operator import re import typing +from collections.abc import Iterator, MutableMapping, Sequence from enum import auto, Enum from itertools import chain from typing import ( @@ -58,8 +57,6 @@ if TYPE_CHECKING: - from collections.abc import Iterator, MutableMapping, Sequence - from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode from ..loop_body import LoopBody from ..scheduler import BaseScheduling, Scheduler, SchedulerNode @@ -89,7 +86,7 @@ class WorkspaceZeroMode(enum.Enum): ZERO_PER_GRAPH = 2 # must be re-zeroed by kernel @staticmethod - def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode: + def combine(a: "WorkspaceZeroMode", b: "WorkspaceZeroMode") -> "WorkspaceZeroMode": if a == b or b == WorkspaceZeroMode.UNINITIALIZED: return a if a == WorkspaceZeroMode.UNINITIALIZED: @@ -97,7 +94,7 @@ def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode: raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})") @staticmethod - def from_bool(zero_fill: bool) -> WorkspaceZeroMode: + def from_bool(zero_fill: bool) -> "WorkspaceZeroMode": if zero_fill: return WorkspaceZeroMode.ZERO_ON_CALL return WorkspaceZeroMode.UNINITIALIZED @@ -128,13 +125,13 @@ def unique_name(prefix: str = "workspace_") -> str: return f"{prefix}{next(V.graph.workspace_id)}" @staticmethod - def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool: + def can_join(a: "WorkspaceArg", b: "WorkspaceArg") -> bool: return ( a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device ) @staticmethod - def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + def join(a: "WorkspaceArg", b: "WorkspaceArg") -> "WorkspaceArg": return WorkspaceArg( count=a.count + b.count, zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), @@ -145,7 +142,7 @@ def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: ) @staticmethod - def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + def maximum(a: "WorkspaceArg", b: "WorkspaceArg") -> "WorkspaceArg": assert ( a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name ) @@ -167,7 +164,7 @@ def get_device(self) -> torch.device: def get_dtype(self) -> torch.dtype: return self.dtype - def get_layout(self) -> FixedLayout: + def get_layout(self) -> "FixedLayout": from ..ir import FixedLayout return FixedLayout( @@ -178,7 +175,7 @@ def get_layout(self) -> FixedLayout: ) @property - def layout(self) -> FixedLayout: + def layout(self) -> "FixedLayout": return self.get_layout() get_output_spec = get_layout @@ -548,14 +545,14 @@ def deduce_output_dtype_by_name( "store_reduction", ): buf_name = args[1] - return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + return V.graph.get_dtype(buf_name) # Hype: ignore[arg-type] elif op_name == "to_dtype_bitcast": return kwargs["dtype"] if "dtype" in kwargs else args[-2] return None def check_dtype( - buffer: IndentedBuffer, var: CSEVariableType, dtype: torch.dtype + buffer: IndentedBuffer, var: "CSEVariableType", dtype: torch.dtype ) -> None: backend = get_current_backend() if config.test_configs.runtime_triton_dtype_assert and backend == "triton": @@ -580,7 +577,7 @@ def check_dtype( class DataTypePropagation: - def __init__(self, body: LoopBody) -> None: + def __init__(self, body: "LoopBody") -> None: self.body = body self.graphs: dict[Union[Callable[..., Any], str], Any] = { "root": body.root_block.graph @@ -624,7 +621,11 @@ def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]: return None if node.target == operator.getitem: - return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type] + from ..ir import IRNode + + node_arg = node.args[0] + assert isinstance(node_arg, IRNode) + return self.deduce_node_dtype(node_arg) assert isinstance(node.target, str) @@ -664,11 +665,11 @@ def propagate(self) -> Optional[torch.dtype]: return self.propagate_graph(self.graphs["root"]) @classmethod - def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]: + def propagate_loopbody(cls, body: "LoopBody") -> Optional[torch.dtype]: return cls(body).propagate() @classmethod - def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]: + def propagate_scheduler_node(cls, node: "SchedulerNode") -> Optional[torch.dtype]: from ..loop_body import LoopBody from ..scheduler import SchedulerNode @@ -1288,7 +1289,7 @@ def __call__(self) -> Optional[str]: return self.line return None - def _new_line(self, line: str) -> DeferredLine: + def _new_line(self, line: str) -> "DeferredLine": return DeferredLine(self.name, line) @@ -1580,7 +1581,7 @@ def python_argdefs( ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]: arg_defs: list[ArgName] = [] call_args: list[str] = [] - arg_types: list[torch.dtype] = [] + arg_types: list[Any] = [] precompile_args: list[KernelArgType] = [] for inplaced in unique(self.inplace_buffers.values()): if isinstance(inplaced, RemovedArg): @@ -1613,7 +1614,7 @@ def python_argdefs( for outer, inner in self.sizevars.items(): arg_defs.append(ArgName(inner)) call_args.append(outer) - arg_types.append(type(outer)) # type: ignore[arg-type] + arg_types.append(type(outer)) # Hype: ignore[arg-type] precompile_args.append(SizeArg(inner, outer)) if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) @@ -1648,7 +1649,7 @@ def is_removed(self, name: str) -> bool: # after you do a call into this kernel, which buffers actually contain # updated data? Modeled off of python_argdefs. def live_output_buffers(self) -> OrderedSet[str]: - live_outs = OrderedSet() # type: ignore[var-annotated] + live_outs = OrderedSet[str]() # Hype: ignore[var-annotated] for inplaced in unique(self.inplace_buffers.values()): if isinstance(inplaced, RemovedArg): continue @@ -1928,7 +1929,7 @@ def __init__( self.kernel_name: Optional[str] = None @contextlib.contextmanager - def set_current_node(self, node: SchedulerNode) -> Iterator[None]: + def set_current_node(self, node: "SchedulerNode") -> Iterator[None]: prior = self.current_node self.current_node = node self.node_to_bounds = node._body.bounds().get_bounds() @@ -2146,7 +2147,9 @@ def rename_indexing( # adds the necessary kernel args for index expressions # and renames variables in index expressions to kernel arg names if isinstance(index, (list, tuple)): - return [self.rename_indexing(x) for x in index] # type: ignore[return-value] + return [ + self.rename_indexing(x) for x in index + ] # Hype: ignore[return-value] index = V.graph.sizevars.simplify(index) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) replacements = { @@ -2166,7 +2169,7 @@ def rename_indexing( def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable: return CSEVariable(*args, **kwargs) - def arg_name(self, node: IRNode) -> Optional[str]: + def arg_name(self, node: "IRNode") -> Optional[str]: """ Returns arg name of a given input or output node. """ @@ -2260,7 +2263,7 @@ def __str__(self) -> str: @staticmethod def _fake_get_dtype( - fake_outs: Union[list[Buffer], Buffer], + fake_outs: Union[list["Buffer"], "Buffer"], ) -> Callable[[str], torch.dtype]: _get_dtype_real = V.graph.get_dtype if isinstance(fake_outs, (list, tuple)): @@ -2302,7 +2305,7 @@ def maybe_append_choice( ) return e - def generate(self, **kwargs: Any) -> ChoiceCaller: + def generate(self, **kwargs: Any) -> "ChoiceCaller": """ Generates a ChoiceCaller instance from the given arguments. """ @@ -2324,7 +2327,9 @@ def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: bounds = self._bound_variable(name, *args, **kwargs) - value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + value = getattr(self.parent_handler, name)( + *args, **kwargs + ) # Hype: ignore[has-type] dtype_handler = DtypePropagationOpsHandler() backend = get_current_backend() @@ -2349,8 +2354,8 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> def do_cse(v: str) -> CSEVariable: # we tree_map over the output, so we need to fetch corresponding dtype nonlocal output_idx - var_dtype: torch.dtype = ( - output_dtype[output_idx] # type: ignore[assignment] + var_dtype: Optional[torch.dtype] = ( + output_dtype[output_idx] # Hype: ignore[assignment] if isinstance(output_dtype, (list, tuple)) else output_dtype ) @@ -2373,6 +2378,7 @@ def do_cse(v: str) -> CSEVariable: config.test_configs.runtime_triton_dtype_assert or config.test_configs.static_cpp_dtype_assert ): + assert var_dtype is not None check_dtype(V.kernel.compute, csevar, var_dtype) return csevar @@ -2429,11 +2435,11 @@ def indirect_indexing( assert isinstance(size, sympy.Expr), size # Skip CSE since this doesn't return an expression - if var.bounds.lower < 0: # type: ignore[operator] + if var.bounds.lower < 0: # Hype: ignore[operator] if wrap_neg: stm = ops.add(var, ops.index_expr(size, torch.long)) # Mixed negative and non-negative - if var.bounds.upper >= 0: # type: ignore[operator] + if var.bounds.upper >= 0: # Hype: ignore[operator] lt = ops.lt(var, 0) stm = ops.where(lt, stm, var) else: @@ -2450,7 +2456,7 @@ def indirect_indexing( neg_bounds.lower + size, neg_bounds.upper + size ) # We don't have a good way of representing the empty range - if var.bounds.upper >= 0: # type: ignore[operator] + if var.bounds.upper >= 0: # Hype: ignore[operator] pos = var.bounds & ValueRanges(0, int_oo) new_bounds = new_bounds | pos @@ -2503,7 +2509,7 @@ def store( self._update_store_cache(name, value) if name not in V.graph.removed_buffers: return self.kernel.store(name, index, value, mode=mode) - return None # type: ignore[return-value] + return None # Hype: ignore[return-value] def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: self.kernel.store_buffer_names.add(name) From dd35a9a586358269025a0efb3f26883397ad1da0 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Mon, 7 Apr 2025 14:56:05 +0000 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 841591cb00714..d9b5929cc4a8f 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -546,7 +546,7 @@ def deduce_output_dtype_by_name( "store_reduction", ): buf_name = args[1] - return V.graph.get_dtype(buf_name) # Hype: ignore[arg-type] + return V.graph.get_dtype(buf_name) elif op_name == "to_dtype_bitcast": return kwargs["dtype"] if "dtype" in kwargs else args[-2] return None @@ -1615,7 +1615,7 @@ def python_argdefs( for outer, inner in self.sizevars.items(): arg_defs.append(ArgName(inner)) call_args.append(outer) - arg_types.append(type(outer)) # Hype: ignore[arg-type] + arg_types.append(type(outer)) precompile_args.append(SizeArg(inner, outer)) if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) @@ -1650,7 +1650,7 @@ def is_removed(self, name: str) -> bool: # after you do a call into this kernel, which buffers actually contain # updated data? Modeled off of python_argdefs. def live_output_buffers(self) -> OrderedSet[str]: - live_outs = OrderedSet[str]() # Hype: ignore[var-annotated] + live_outs = OrderedSet[str]() for inplaced in unique(self.inplace_buffers.values()): if isinstance(inplaced, RemovedArg): continue @@ -2147,9 +2147,7 @@ def rename_indexing( # adds the necessary kernel args for index expressions # and renames variables in index expressions to kernel arg names if isinstance(index, (list, tuple)): - return [ - self.rename_indexing(x) for x in index - ] # Hype: ignore[return-value] + return [self.rename_indexing(x) for x in index] index = V.graph.sizevars.simplify(index) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) replacements = { @@ -2327,9 +2325,7 @@ def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: bounds = self._bound_variable(name, *args, **kwargs) - value = getattr(self.parent_handler, name)( - *args, **kwargs - ) # Hype: ignore[has-type] + value = getattr(self.parent_handler, name)(*args, **kwargs) dtype_handler = DtypePropagationOpsHandler() backend = get_current_backend() @@ -2355,7 +2351,7 @@ def do_cse(v: str) -> CSEVariable: # we tree_map over the output, so we need to fetch corresponding dtype nonlocal output_idx var_dtype: Optional[torch.dtype] = ( - output_dtype[output_idx] # Hype: ignore[assignment] + output_dtype[output_idx] if isinstance(output_dtype, (list, tuple)) else output_dtype ) @@ -2435,11 +2431,11 @@ def indirect_indexing( assert isinstance(size, sympy.Expr), size # Skip CSE since this doesn't return an expression - if var.bounds.lower < 0: # Hype: ignore[operator] + if var.bounds.lower < 0: if wrap_neg: stm = ops.add(var, ops.index_expr(size, torch.long)) # Mixed negative and non-negative - if var.bounds.upper >= 0: # Hype: ignore[operator] + if var.bounds.upper >= 0: lt = ops.lt(var, 0) stm = ops.where(lt, stm, var) else: @@ -2456,7 +2452,7 @@ def indirect_indexing( neg_bounds.lower + size, neg_bounds.upper + size ) # We don't have a good way of representing the empty range - if var.bounds.upper >= 0: # Hype: ignore[operator] + if var.bounds.upper >= 0: pos = var.bounds & ValueRanges(0, int_oo) new_bounds = new_bounds | pos @@ -2508,8 +2504,7 @@ def store( if mode is None: self._update_store_cache(name, value) if name not in V.graph.removed_buffers: - return self.kernel.store(name, index, value, mode=mode) - return None # Hype: ignore[return-value] + self.kernel.store(name, index, value, mode=mode) def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: self.kernel.store_buffer_names.add(name) From d99e1dbb67105ccbe33d0f6a045212e05a9dd8ea Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Mon, 7 Apr 2025 15:50:46 +0000 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- torch/_inductor/codecache.py | 98 +++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 46 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index e331badcff34f..a8ff222616b12 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import base64 import copyreg import dataclasses @@ -24,6 +22,8 @@ import threading import warnings from bisect import bisect_right +from collections.abc import Generator, KeysView, Sequence +from concurrent.futures import Future from copy import copy from ctypes import c_void_p, CDLL, cdll from datetime import timedelta @@ -107,23 +107,22 @@ ) else: - def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + def log_global_cache_errors( + *args: Any, **kwargs: Any + ) -> None: pass - def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: pass - def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: pass - def use_global_cache() -> bool: # type: ignore[misc] + def use_global_cache() -> bool: return False if TYPE_CHECKING: - from collections.abc import Generator, KeysView, Sequence - from concurrent.futures import Future - from .compile_fx import _CompileFxKwargs, CompiledFxGraph from .graph import GraphLowering from .ir import ChoiceCaller @@ -262,11 +261,11 @@ def get_global_cache(self) -> dict[str, Any]: def lookup( self, - choices: list[ChoiceCaller], + choices: list["ChoiceCaller"], op: str, inputs: str, - benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]], - ) -> dict[ChoiceCaller, float]: + benchmark: Optional[Callable[[Any], dict["ChoiceCaller", float]]], + ) -> dict["ChoiceCaller", float]: """ Check to see if we have benchmarked the given choice callers. For each choice caller: @@ -612,7 +611,7 @@ def get_hash(self, obj: Any) -> str: serialized_data = self.dumps(obj) return sha256_hash(serialized_data) - def debug_lines(self, inp: FxGraphHashDetails) -> list[str]: + def debug_lines(self, inp: "FxGraphHashDetails") -> list[str]: """ Get a printable string describing in more detail all the attributes comprising an object. Useful for debugging when one graph hashes @@ -729,8 +728,8 @@ class FxGraphHashDetails: def __init__( self, gm: torch.fx.GraphModule, - example_inputs: Sequence[InputType], - fx_kwargs: _CompileFxKwargs, + example_inputs: Sequence["InputType"], + fx_kwargs: "_CompileFxKwargs", inputs_to_check: Sequence[int], ) -> None: self.gm = gm @@ -746,7 +745,8 @@ def __init__( if type(v) in (set, OrderedSet): # noqa: set_linter # Special case to handle set params. Python sets can't be # ordered, so sort the elements and store them in a proxy. - self.fx_kwargs[k] = OrderedSetHolder(sorted(v)) # type: ignore[call-overload] + assert isinstance(v, Sequence) + self.fx_kwargs[k] = OrderedSetHolder(sorted(v)) else: self.fx_kwargs[k] = v @@ -847,8 +847,8 @@ def _get_custom_pass_detail( def compiled_fx_graph_hash( gm: torch.fx.GraphModule, - example_inputs: Sequence[InputType], - fx_kwargs: _CompileFxKwargs, + example_inputs: Sequence["InputType"], + fx_kwargs: "_CompileFxKwargs", inputs_to_check: Sequence[int], ) -> tuple[str, list[str]]: """ @@ -940,7 +940,7 @@ def _get_tmp_dir_for_key(key: str) -> str: return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key) @staticmethod - def _filter_backed_symints(inputs: Sequence[InputType]) -> list[torch.SymInt]: + def _filter_backed_symints(inputs: Sequence["InputType"]) -> list[torch.SymInt]: """ Get the backed SymInt objects from the input list. Note that we can never have guards that depend on unbacked symint. @@ -960,11 +960,11 @@ def _get_shape_env() -> Optional[ShapeEnv]: @staticmethod def _lookup_graph( key: str, - example_inputs: Sequence[InputType], + example_inputs: Sequence["InputType"], local: bool, - remote_cache: Optional[RemoteCache[JsonDataTy]], - constants: CompiledFxGraphConstants, - ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]: + remote_cache: Optional["RemoteCache[JsonDataTy]"], + constants: "CompiledFxGraphConstants", + ) -> tuple[Optional["CompiledFxGraph"], dict[str, Any]]: """ Lookup a compiled graph in the cache by key. On a hit, return the deserialized CompiledFxGraph object. On a miss, return None. @@ -976,7 +976,7 @@ def _lookup_graph( hints = [hint_int(s) for s in symints] def iterate_over_candidates() -> Generator[ - tuple[CompiledFxGraph, bytes], None, None + tuple["CompiledFxGraph", bytes], None, None ]: if local: subdir = FxGraphCache._get_tmp_dir_for_key(key) @@ -1113,10 +1113,10 @@ def _write_to_local_cache(key: str, content: bytes) -> None: @staticmethod def _save_graph( key: str, - compiled_graph: OutputCode, - example_inputs: Sequence[InputType], + compiled_graph: "OutputCode", + example_inputs: Sequence["InputType"], local: bool, - remote_cache: Optional[RemoteCache[JsonDataTy]], + remote_cache: Optional["RemoteCache[JsonDataTy]"], ) -> None: """ Store a serialized CompiledFxGraph on disk. @@ -1229,8 +1229,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: @staticmethod def prepare_key( gm: torch.fx.GraphModule, - example_inputs: Sequence[InputType], - fx_kwargs: _CompileFxKwargs, + example_inputs: Sequence["InputType"], + fx_kwargs: "_CompileFxKwargs", inputs_to_check: Sequence[int], remote: bool, ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]: @@ -1264,7 +1264,7 @@ def prepare_key( return (key, debug_lines), {} @staticmethod - def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + def get_remote_cache() -> Optional["RemoteCache[JsonDataTy]"]: """ Attempts to load the remote cache, returns None on error. """ @@ -1280,12 +1280,12 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: def load_with_key( key: str, debug_lines: list[str], - example_inputs: Sequence[InputType], + example_inputs: Sequence["InputType"], local: bool, - remote_cache: Optional[RemoteCache[JsonDataTy]], + remote_cache: Optional["RemoteCache[JsonDataTy]"], is_backward: bool, - constants: CompiledFxGraphConstants, - ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]: + constants: "CompiledFxGraphConstants", + ) -> tuple[Optional["CompiledFxGraph"], dict[str, Any]]: """ Lookup the graph with the given key, and return results and metadata. Doesn't do any logging on its own, because AOTAutograd handles a cache miss @@ -1392,7 +1392,7 @@ class AotCodeCompiler: @classmethod def compile( cls, - graph: GraphLowering, + graph: "GraphLowering", wrapper_code: str, kernel_code: str, serialized_extern_kernel_nodes: Optional[str], @@ -1966,7 +1966,7 @@ def convert_arg(arg: Any) -> Any: result = [torch.tensor([]) if r is None else r for r in result] for i, r in enumerate(result): assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" - return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] + return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) else: assert isinstance(result, torch.Tensor), op + " returns a non-tensor" return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result) @@ -2308,7 +2308,8 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType: assert spec is not None module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore[union-attr] + assert spec.loader is not None + spec.loader.exec_module(module) return module @classmethod @@ -2515,7 +2516,9 @@ class HalideCodeCache(CppPythonBindingsCodeCache): ) @classmethod - def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]: + def _codegen_buffer( + cls, name: str, arg: "HalideInputSpec", cuda: bool + ) -> list[str]: assert arg.shape is not None assert arg.stride is not None and len(arg.shape) == len(arg.stride) assert arg.offset is not None @@ -2549,7 +2552,7 @@ def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[st ] @classmethod - def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str: + def _codegen_glue(cls, meta: "HalideMeta", headerfile: object) -> str: is_cuda = meta.is_cuda() assert is_cuda is ("user_context" in meta.target) assert "no_runtime" in meta.target @@ -2657,7 +2660,7 @@ def find_header(name: str) -> str: @classmethod def generate_halide_async( - cls, meta: HalideMeta, source_code: str, submit_fn: Any = None + cls, meta: "HalideMeta", source_code: str, submit_fn: Any = None ) -> Callable[[], Any]: dirpath = Path( get_path( @@ -2797,6 +2800,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None: job() except subprocess.SubprocessError as e: if os.environ.get("HALIDE_REPRO") == "1": + cmd: list[Any] python, script, *cmd = getattr(e, "cmd", ("", "", "")) if os.path.basename(python).startswith("python"): code = open(script).read() @@ -2807,7 +2811,9 @@ class Out: def __repr__(self) -> str: return "out" - cmd[cmd.index("-o") + 1] = Out() # type: ignore[call-overload] + ci = cmd.index("-o") + assert isinstance(ci, int) + cmd[ci + 1] = Out() repl = textwrap.indent( textwrap.dedent( f"""\ @@ -2934,7 +2940,7 @@ def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]: def _load_triton_kernel_from_source( kernel_name: str, source_code: str -) -> CachingAutotuner: +) -> "CachingAutotuner": return getattr(PyCodeCache.load(source_code), kernel_name) @@ -3349,7 +3355,7 @@ def __init__( self.result_fn = result_fn self.future = future - def result(self) -> Callable[..., Any]: # type: ignore[override] + def result(self) -> Callable[..., Any]: return self.result_fn() @@ -3358,7 +3364,7 @@ class StaticAutotunerFuture(CodeCacheFuture): A statically launchable CachingAutotuner, loaded from TritonBundler """ - def __init__(self, static_autotuner: CachingAutotuner) -> None: + def __init__(self, static_autotuner: "CachingAutotuner") -> None: # Pickled version of CachingAutotuner self.static_autotuner = static_autotuner # This needs to be set in AsyncCompile.triton, in case @@ -3367,10 +3373,10 @@ def __init__(self, static_autotuner: CachingAutotuner) -> None: # since it can be very large. self.reload_kernel_from_src: Optional[Callable[[], Any]] = None - def result(self) -> CachingAutotuner: + def result(self) -> "CachingAutotuner": assert self.reload_kernel_from_src is not None with dynamo_timed("StaticAutotunerFuture.warm_precompile"): - self.static_autotuner.precompile( # type: ignore[union-attr] + self.static_autotuner.precompile( warm_cache_only=False, reload_kernel=self.reload_kernel_from_src, static_triton_bundle_key=None, # no need to save again From e3c6fa719df14932faeff11f5bd207d14d447fe7 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Tue, 8 Apr 2025 16:54:21 +0000 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- torch/_inductor/codecache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index bbbfe8703ae1f..6248d95927482 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -644,7 +644,7 @@ def get_str(obj: Any) -> str: def build_code_hash( - roots: list[str] | None, prefix: str, hasher: hashlib._Hash + roots: Optional[list[str]], prefix: str, hasher: hashlib._hashlib.HASH ) -> None: for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): spec = lib.module_finder.find_spec(lib.name, None) @@ -2086,7 +2086,7 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: raise @classmethod - def _get_uncompiled_header(cls, device: str) -> str | None: + def _get_uncompiled_header(cls, device: str) -> Optional[str]: """ Given a device type, returns the path to a CPP header file to be precompiled. Currently, this is only utilized by the cpp_wrapper classes. @@ -2427,7 +2427,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache): ) @classmethod - def _get_uncompiled_header(cls, device: str) -> str | None: + def _get_uncompiled_header(cls, device: str) -> Optional[str]: """ Given a device type, returns the path to a CPP header file to be precompiled. Currently, this is only utilized by the cpp_wrapper classes. From fb970cfcad32f5c9c2efd4ee07f6a48e7e8c7a11 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Tue, 8 Apr 2025 17:53:28 +0000 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- torch/_inductor/codecache.py | 2 +- torch/_inductor/codegen/common.py | 38 +++++++++++++++---------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6248d95927482..c5f76ff0d087a 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -644,7 +644,7 @@ def get_str(obj: Any) -> str: def build_code_hash( - roots: Optional[list[str]], prefix: str, hasher: hashlib._hashlib.HASH + roots: Optional[list[str]], prefix: str, hasher: "hashlib._Hash" ) -> None: for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): spec = lib.module_finder.find_spec(lib.name, None) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index d9b5929cc4a8f..db89bec7e8d60 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -348,7 +348,7 @@ def get_backend_features( if isinstance(device, torch.device): device_type = device.type else: - assert isinstance(device, str) + assert isinstance(device, str), type(device) device_type = device device = torch.device(device_type) scheduling_ctor = get_scheduling_for_device(device_type) @@ -361,7 +361,7 @@ def has_backend_feature( device: Union[torch.device, str, None], feature: BackendFeature ) -> bool: """See also V.graph.has_feature""" - assert isinstance(feature, BackendFeature) + assert isinstance(feature, BackendFeature), type(feature) return feature in get_backend_features(device) @@ -478,7 +478,7 @@ def register_device_op_overrides( def get_device_op_overrides(device: str) -> DeviceOpOverrides: - assert isinstance(device, str) + assert isinstance(device, str), type(device) if not device_op_overrides_dict: from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401 @@ -561,7 +561,7 @@ def check_dtype( elif config.test_configs.static_cpp_dtype_assert and backend == "cpp": from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP - assert isinstance(var, CppCSEVariable) + assert isinstance(var, CppCSEVariable), type(var) if dtype == torch.bool: if var.is_vec: is_same_dt = f"IsVecMaskType::value" @@ -622,13 +622,11 @@ def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]: return None if node.target == operator.getitem: - from ..ir import IRNode - node_arg = node.args[0] - assert isinstance(node_arg, IRNode) + assert isinstance(node_arg, torch.fx.Node), type(node_arg) return self.deduce_node_dtype(node_arg) - assert isinstance(node.target, str) + assert isinstance(node.target, str), type(node.target) if node.target.startswith("masked_subblock"): return self.deduce_node_dtype_by_subgraph(node) @@ -674,8 +672,8 @@ def propagate_scheduler_node(cls, node: "SchedulerNode") -> Optional[torch.dtype from ..loop_body import LoopBody from ..scheduler import SchedulerNode - assert isinstance(node, SchedulerNode) - assert isinstance(node._body, LoopBody) + assert isinstance(node, SchedulerNode), type(node) + assert isinstance(node._body, LoopBody), type(node._body) return DataTypePropagation.propagate_loopbody(node._body) @@ -1283,7 +1281,7 @@ class DeferredLine(DeferredLineBase): def __init__(self, name: str, line: str): super().__init__(line) self.name = name - assert not isinstance(line, DeferredLineBase) + assert not isinstance(line, DeferredLineBase), type(line) def __call__(self) -> Optional[str]: if not is_buffer_removed(self.name): @@ -1399,7 +1397,7 @@ def output(self, name: str) -> str: return self._lookup("out_ptr", self.output_buffers, name) def make_inplace(self, input_name: str, output_name: str) -> None: - assert output_name not in self.inplace_buffers + assert output_name not in self.inplace_buffers, output_name if input_name in self.inplace_buffers: buf = self.inplace_buffers[input_name] assert not isinstance(buf, RemovedArg) @@ -1461,7 +1459,7 @@ def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]: assert ( existing_arg.inner_name != arg.inner_name and existing_arg.outer_name != arg.outer_name - ) + ), existing_arg self.workspace_args.append(arg) return arg.inner_name, 0 @@ -1489,7 +1487,7 @@ def semaphores(self, min_size: sympy.Expr) -> str: ) for existing_arg in self.workspace_args: if existing_arg.inner_name == arg.inner_name: - assert arg == existing_arg + assert arg == existing_arg, (arg, existing_arg) self.workspace_args.append(arg) return arg.inner_name @@ -1676,7 +1674,7 @@ def __init__( dtype: Optional[torch.dtype] = None, ): super().__init__() - assert isinstance(bounds, ValueRanges) + assert isinstance(bounds, ValueRanges), type(bounds) self.name = name self.bounds = bounds self.use_count = 1 # track how many times this expression is used @@ -1806,7 +1804,7 @@ def generate( elif isinstance(expr, DeferredLineBase): cache_key = expr.line else: - assert isinstance(expr, str) + assert isinstance(expr, str), type(expr) cache_key = expr var = self.try_get(cache_key) if not var: @@ -2047,7 +2045,7 @@ def indirect_assert( ) -> str: if isinstance(var, CSEVariable): var = str(var) - assert isinstance(var, str) + assert isinstance(var, str), type(var) assert lower is None or isinstance(lower, str) assert upper is None or isinstance(upper, str) if lower and upper: @@ -2393,7 +2391,9 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A fx_node = V.interpreter.current_node if fx_node.target == name and self.kernel.node_to_bounds is not None: - assert isinstance(self.kernel.node_to_bounds, dict) + assert isinstance(self.kernel.node_to_bounds, dict), type( + self.kernel.node_to_bounds + ) return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown()) elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): # These create lots of inner strings. We would need to compute the bounds at the ops @@ -2428,7 +2428,7 @@ def indirect_indexing( ) -> sympy.Symbol: if isinstance(size, int): size = sympy.Integer(size) - assert isinstance(size, sympy.Expr), size + assert isinstance(size, sympy.Expr), (type(size), size) # Skip CSE since this doesn't return an expression if var.bounds.lower < 0: From 2741c5fb1799b9b133032fc2d09e1bc8208284fd Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Wed, 9 Apr 2025 11:04:35 +0000 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- torch/_inductor/codecache.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c5f76ff0d087a..29435252f5b15 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -743,8 +743,7 @@ def __init__( if type(v) in (set, OrderedSet): # noqa: set_linter # Special case to handle set params. Python sets can't be # ordered, so sort the elements and store them in a proxy. - assert isinstance(v, Sequence) - self.fx_kwargs[k] = OrderedSetHolder(sorted(v)) + self.fx_kwargs[k] = OrderedSetHolder(sorted(v)) # type: ignore[call-overload] else: self.fx_kwargs[k] = v From c46ec03a541e1c50c6a23481fad1ea9b6e4a2a9a Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Thu, 10 Apr 2025 13:04:53 +0000 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- .../pr_time_benchmarks/expected_results.csv | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 931354033f676..e1e2f3b81f3af 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,32 +1,32 @@ -add_loop_eager,compile_time_instruction_count,2851000000,0.015 +add_loop_eager,compile_time_instruction_count,2963000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,5541000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,5675000000,0.025 -add_loop_inductor,compile_time_instruction_count,28310000000,0.015 +add_loop_inductor,compile_time_instruction_count,29130000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,41190000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43000000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,24600000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,25530000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,960700000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,972000000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17420000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18110000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15710000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16210000000,0.015 @@ -34,44 +34,44 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,97140000 -update_hint_regression,compile_time_instruction_count,1543000000,0.02 +update_hint_regression,compile_time_instruction_count,1610000000,0.02 -float_args,compile_time_instruction_count,413700000,0.015 +float_args,compile_time_instruction_count,418700000,0.015 -sum_floordiv_regression,compile_time_instruction_count,975700000,0.015 +sum_floordiv_regression,compile_time_instruction_count,987900000,0.015 -symint_sum,compile_time_instruction_count,3104000000,0.015 +symint_sum,compile_time_instruction_count,3222000000,0.015 -symint_sum_loop,compile_time_instruction_count,4052000000,0.015 +symint_sum_loop,compile_time_instruction_count,4216000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2019000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2056000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5844000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5918000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,7989000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8553000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1746000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1878000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3645000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3788000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9999000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10280000000,0.015