[inductor] Add typing to common.CSE (#145993) · pytorch/pytorch@e9f6e27 · GitHub

Commit e9f6e27

jansel authored and pytorchmergebot committed
[inductor] Add typing to common.CSE (#145993)
Pull Request resolved: #145993
Approved by: https://github.com/yanboliang
ghstack dependencies: #145916
1 parent 7a5239a commit e9f6e27

10 files changed: +193 -119 lines changed

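Note on the typing pattern: the core of this change is making CSE generic over its variable type and its cache-key type without breaking callers that use plain, unparameterized CSE. Below is a minimal, self-contained sketch of that pattern, assuming only a typing_extensions recent enough for PEP 696 TypeVar defaults (the same import the diff adds); MiniVar, MiniCSE, MaskedVar, and MaskedCSE are illustrative stand-ins, not the actual inductor classes.

from typing import Generic, MutableMapping, Optional, cast

from typing_extensions import TypeVar


class MiniVar:
    def __init__(self, name: str) -> None:
        self.name = name


# TypeVars with defaults (PEP 696): bare MiniCSE still means MiniCSE[MiniVar, str].
AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
VarT = TypeVar("VarT", bound=MiniVar, default=MiniVar)


class MiniCSE(Generic[VarT, AugmentedKeyT]):
    """Toy stand-in for inductor's CSE, generic over variable and key types."""

    def __init__(self) -> None:
        self._cache: MutableMapping[AugmentedKeyT, VarT] = {}

    def augment_key(self, cache_key: str) -> AugmentedKeyT:
        # Backends override this to fold backend-specific context into the key.
        return cast(AugmentedKeyT, cache_key)

    def put(self, cache_key: str, val: VarT) -> None:
        self._cache[self.augment_key(cache_key)] = val

    def try_get(self, cache_key: str) -> Optional[VarT]:
        return self._cache.get(self.augment_key(cache_key), None)


class MaskedVar(MiniVar):
    mask: Optional[str] = None


class MaskedCSE(MiniCSE[MaskedVar, tuple[str, Optional[str]]]):
    """Augments keys with the active mask so masked and unmasked computations
    of the same expression never collapse into one variable."""

    current_mask: Optional[str] = None

    def augment_key(self, cache_key: str) -> tuple[str, Optional[str]]:
        return (cache_key, self.current_mask)


plain: MiniCSE = MiniCSE()  # defaults apply: MiniCSE[MiniVar, str]
masked = MaskedCSE()
masked.put("x + y", MaskedVar("tmp0"))
assert masked.try_get("x + y") is not None
assert plain.try_get("x + y") is None

The same defaults are what let the new annotation self.cse: CSE[CSEVariableType, Any] in Kernel.__init__ coexist with older code that spells the type as just CSE.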

torch/_inductor/codegen/common.py

Lines changed: 98 additions & 78 deletions
@@ -10,19 +10,23 @@
 import math
 import operator
 import re
+import typing
 from enum import auto, Enum
 from itertools import chain
 from typing import (
     Any,
     Callable,
     cast,
     ClassVar,
+    Generic,
     Iterator,
+    MutableMapping,
     NamedTuple,
     Optional,
     TYPE_CHECKING,
     Union,
 )
+from typing_extensions import TypeVar
 
 import sympy
 
@@ -44,6 +48,7 @@
     generate_assert,
     IndentedBuffer,
     ir_dataclass,
+    ScopedDict,
     sympy_dot,
     sympy_subs,
     unique,
@@ -52,11 +57,9 @@
 
 
 if TYPE_CHECKING:
-    from typing import Never, TypeVar
-
-    from ..ir import FixedLayout
+    from ..ir import FixedLayout, IRNode
     from ..loop_body import LoopBody
-    from ..scheduler import BaseScheduling, Scheduler
+    from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
     from .wrapper import PythonWrapperCodegen
 
 _T = TypeVar("_T")
@@ -1341,6 +1344,18 @@ def call_names(self) -> Iterator[str]:
             self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
         )
 
+    def arg_name(self, name: str) -> Optional[str]:
+        """
+        Returns inner name of a given outer name.
+        """
+        inplaced = self.inplace_buffers.get(name, None)
+        if inplaced is not None and not isinstance(inplaced, RemovedArg):
+            return inplaced.inner_name
+        output_name = self.output_buffers.get(name, None)
+        if output_name is not None and not isinstance(output_name, RemovedArg):
+            return output_name
+        return self.input_buffers.get(name, None)
+
     def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
         return buf
 
@@ -1482,17 +1497,18 @@ class CSEVariable:
 
     def __init__(
         self,
-        name,
+        name: str,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
     ):
+        super().__init__()
         assert isinstance(bounds, ValueRanges)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
         self.dtype = dtype
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.name
 
     def __hash__(self) -> int:
@@ -1501,68 +1517,86 @@ def __hash__(self) -> int:
     def __eq__(self, other) -> bool:
         return type(other) == type(self) and other.name == self.name
 
-    def update_on_args(self, name, args, kwargs):
+    def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
         pass
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.name!r})"
 
 
-class CSE:
+AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
+CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)
+
+if TYPE_CHECKING:
+    ReductionCacheKey = tuple[
+        torch.dtype,
+        ReductionType,
+        Union[CSEVariable, tuple[CSEVariable, ...]],
+    ]
+
+
+class CSE(Generic[CSEVariableType, AugmentedKeyT]):
     """Common subexpression elimination"""
 
     def __init__(
         self,
-        prefix="",
-        suffix="",
-        name_prefix="tmp",
-        iter_buffers=None,
-        store_cache=None,
-        reduction_cache=None,
-        varname_map=None,
+        prefix: str = "",
+        suffix: str = "",
+        name_prefix: str = "tmp",
+        iter_buffers: Optional[itertools.count[int]] = None,
+        store_cache: Optional[MutableMapping[str, CSEVariableType]] = None,
+        reduction_cache: Optional[
+            MutableMapping[ReductionCacheKey, CSEVariableType]
+        ] = None,
+        varname_map: Optional[dict[str, CSEVariableType]] = None,
    ):
         self.prefix = prefix
         self.suffix = suffix
-        self._cache = {}
+        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}
         self.name_prefix = name_prefix
-        self.store_cache = store_cache or {}
-        self.reduction_cache = reduction_cache or {}
-        self.iter_buffer_ids = iter_buffers or itertools.count()
-        self.invalidated_stores = OrderedSet[str]()
-        self.varname_map = varname_map or {}
-
-    def invalidate(self, keep_vars: Union[OrderedSet[str], OrderedSet[Never]]):
-        for name, tmp in list(self.store_cache.items()):
+        self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {}
+        self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = (
+            reduction_cache or {}
+        )
+        self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count()
+        self.invalidated_stores: OrderedSet[str] = OrderedSet()
+        self.varname_map: dict[str, CSEVariableType] = varname_map or {}
+
+    def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
+        for name, tmp in [*self.store_cache.items()]:
             if tmp not in keep_vars:
                 del self.store_cache[name]
                 self.invalidated_stores.add(name)
-        self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
+        if keep_vars:
+            self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
+        else:
+            self._cache = {}
 
-    def clone(self):
-        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
+    def clone(self) -> typing.Self:
         return type(self)(
             prefix=self.prefix,
             suffix=self.suffix,
             name_prefix=self.name_prefix,
             iter_buffers=self.iter_buffer_ids,
             store_cache=self.store_cache,
             varname_map=self.varname_map,
+            reduction_cache=self.reduction_cache,
         )
 
-    def augment_key(self, cache_key: object) -> object:
+    def augment_key(self, cache_key: str) -> AugmentedKeyT:
         "Override this method to augment cache key with backend specifics"
-        return cache_key
+        return cast(AugmentedKeyT, cache_key)
 
-    def put(self, cache_key: object, val: CSEVariable) -> None:
+    def put(self, cache_key: str, val: CSEVariableType) -> None:
         self._cache[self.augment_key(cache_key)] = val
 
-    def contains(self, cache_key) -> bool:
+    def contains(self, cache_key: str) -> bool:
         return self.augment_key(cache_key) in self._cache
 
-    def try_get(self, cache_key: object) -> Optional[CSEVariable]:
+    def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
         return self._cache.get(self.augment_key(cache_key), None)
 
-    def get(self, cache_key: object) -> CSEVariable:
+    def get(self, cache_key: str) -> CSEVariableType:
         return self._cache[self.augment_key(cache_key)]
 
     def generate(
@@ -1571,10 +1605,10 @@ def generate(
         expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase],
         *,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
-        write=True,
-        assignment=True,
+        write: bool = True,
+        assignment: bool = True,
         dtype: Optional[torch.dtype] = None,
-    ) -> CSEVariable:
+    ) -> CSEVariableType:
         if isinstance(expr, OpsValue):
             expr = expr.value
 
@@ -1585,7 +1619,7 @@ def generate(
             # with the loose ValueRanges.unknown(), so we need to tighten the bounds
             expr.bounds = expr.bounds.tighten(bounds)
             expr.use_count += 1
-            return expr
+            return cast(CSEVariableType, expr)
         elif isinstance(expr, IndentedBuffer):
             cache_key = expr.getvalue()
         elif isinstance(expr, DeferredLineBase):
@@ -1628,7 +1662,7 @@ def newvar(
         self,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
-    ) -> CSEVariable:
+    ) -> CSEVariableType:
         var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
         var = V.kernel.create_cse_var(var_name, bounds, dtype)
         self.varname_map[var_name] = var
@@ -1639,7 +1673,7 @@ def namedvar(
         name: str,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
-    ) -> CSEVariable:
+    ) -> CSEVariableType:
         torch._check_value(
             name not in self.varname_map, lambda: f"duplicate name: {name}"
         )
@@ -1653,45 +1687,22 @@ def __init__(self) -> None:
         super().__init__()
         self.exit_stack = contextlib.ExitStack()
 
-    def __enter__(self):
+    def __enter__(self) -> typing.Self:
         self.exit_stack.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
 
 
-class ScopedDict:
-    def __init__(self, original_dict):
-        self.original_dict = original_dict
-        self.new_items = {}
-
-    def __getitem__(self, key):
-        if key in self.new_items:
-            return self.new_items[key]
-        return self.original_dict[key]
-
-    def __setitem__(self, key, value):
-        self.new_items[key] = value
-
-    def __contains__(self, key):
-        return key in self.new_items or key in self.original_dict
-
-    def get(self, key, default=None):
-        if key in self.new_items:
-            return self.new_items[key]
-        return self.original_dict.get(key, default)
-
-
-class Kernel(CodeGen):
-    newvar_prefix = ""
-    suffix = ""
+class Kernel(CodeGen, Generic[CSEVariableType]):
+    newvar_prefix: str = ""
+    suffix: str = ""
     overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
-    # TODO: these look dead, but with all the getattr it's hard to tell...
-    load_format: None = None
-    store_format: None = None
 
-    def __init__(self, args=None, increase_kernel_count=True):
+    def __init__(
+        self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
+    ) -> None:
         super().__init__()
         if increase_kernel_count:
             metrics.generated_kernel_count += 1
@@ -1703,13 +1714,13 @@ def __init__(self, args=None, increase_kernel_count=True):
         self.num_load = 0
         self.num_reduction = 0
 
-        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
+        self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
         self.must_keep_buffers = OrderedSet[str]()
         self.store_buffer_names = OrderedSet[str]()
-        self._load_mask = None
-        self._load_other = None
+        self._load_mask: Optional[str] = None
+        self._load_other: Union[None, int, float] = None
         # OrderedSet in set_current_node
-        self.current_node = None
+        self.current_node: Optional[SchedulerNode] = None
         self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None
 
         self.removed_buffers = OrderedSet[str]()
@@ -1718,10 +1729,10 @@ def __init__(self, args=None, increase_kernel_count=True):
         # key: the buffer to write
         # value: the buffer to read and whose memory can be reused for
        # the buffer specified by key
-        self.inplace_update_buffers = {}
+        self.inplace_update_buffers: dict[str, str] = {}
         # Set minimum number of elements processed per thread.
         self.min_elem_per_thread = 1
-        self.kernel_name = None
+        self.kernel_name: Optional[str] = None
 
     @contextlib.contextmanager
     def set_current_node(self, node):
@@ -1735,7 +1746,7 @@ def set_current_node(self, node):
 
     @contextlib.contextmanager
     def swap_buffers(self, lb, cb=None, sb=None):
-        def scope_cse(cse):
+        def scope_cse(cse: CSE[CSEVariableType, Any]):
             new_cse = cse.clone()
             new_cse._cache = ScopedDict(cse._cache)
             new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
@@ -2062,6 +2073,7 @@ def load(name: str, index: sympy.Expr) -> CSEVariable:
 
             @staticmethod
             def _update_store_cache(name: str, value: CSEVariable):
+                value = cast(CSEVariableType, value)
                 self.cse.store_cache[name] = value
                 if self.current_node and name in V.graph.name_to_buffer:
                     buf = self.current_node.get_output(name)
@@ -2288,6 +2300,14 @@ def rename_indexing(self, index) -> sympy.Expr:
     def create_cse_var(self, *args, **kwargs):
         return CSEVariable(*args, **kwargs)
 
+    def arg_name(self, node: IRNode) -> Optional[str]:
+        """
+        Returns arg name of a given input or output node.
+        """
+        if node is None:
+            return None
+        return self.args.arg_name(node.get_name())
+
 
 
 @dataclasses.dataclass
 class OptimizationContext:
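Usage note on the swap_buffers hunk above: scope_cse wraps cse._cache and cse.reduction_cache in ScopedDict, which this commit now imports from torch/_inductor/utils instead of keeping the module-local copy deleted above. A small sketch of those semantics, assuming the utils version behaves like the removed copy (the ScopedDictSketch name is illustrative):

class ScopedDictSketch:
    """Reads fall through to the wrapped dict; writes stay in a scratch layer,
    so a scoped CSE cache never mutates the parent kernel's cache."""

    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)


base_cache = {"x + y": "tmp0"}
scoped = ScopedDictSketch(base_cache)
scoped["x * y"] = "tmp1"              # only visible through the scoped view
assert scoped.get("x + y") == "tmp0"  # read falls through to the parent cache
assert "x * y" not in base_cache      # the parent cache is untouched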

torch/_inductor/codegen/cpp.py

Lines changed: 1 addition & 1 deletion
@@ -2646,7 +2646,7 @@ def load(self, name: str, index: sympy.Expr):
             return super().load(name, index)
         elif stride == 1:
             # load contiguously
-            line = self._get_vec_load_line(var, index, dtype, self._load_mask)
+            line = self._get_vec_load_line(var, index, dtype, self._load_mask)  # type: ignore[arg-type]
             csevar = self.cse.generate(self.loads, line)  # type: ignore[assignment]
         else:
             csevar = self._load_or_store_non_contiguous(var, index, dtype)  # type: ignore[assignment]

torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 1 addition & 10 deletions
@@ -163,16 +163,6 @@ def __init__(self, kernel_name) -> None:
         super().__init__()
         self.kernel_name = kernel_name
 
-    def arg_name(self, node: IRNode) -> Optional[str]:
-        """
-        Returns arg name of a given input or output node.
-        """
-        if node is None:
-            return None
-        return {**self.args.input_buffers, **self.args.output_buffers}.get(
-            node.get_name(), None
-        )
-
     def check_not_null(self, node: IRNode) -> str:
         """
         Generates code to check that a node is not null.
@@ -273,6 +263,7 @@ def call_kernel(
         """
         wrapper = V.graph.wrapper_code
 
+        arg_types: list[Any]
         if V.graph.cpp_wrapper:
             # Make sure we initialize these kernels since they're exported as
             # C-style symbol names.
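Context for the cuda_kernel.py hunks: CUDATemplateKernel.arg_name previously merged input_buffers and output_buffers itself; it now delegates to the KernelArgs.arg_name helper added in common.py above, which also resolves in-place buffers and skips RemovedArg entries. (The separate arg_types: list[Any] line added in call_kernel simply pre-declares that local's type ahead of the branch that assigns it.) A rough, hypothetical stand-in illustrating the new lookup order (MiniKernelArgs, InplacedBuffer, and this RemovedArg are illustrative, not the real inductor types):

from dataclasses import dataclass, field
from typing import Optional


class RemovedArg:
    """Marker for an argument dropped from the kernel signature (stand-in)."""


@dataclass
class InplacedBuffer:
    inner_name: str


@dataclass
class MiniKernelArgs:
    input_buffers: dict = field(default_factory=dict)
    output_buffers: dict = field(default_factory=dict)
    inplace_buffers: dict = field(default_factory=dict)

    def arg_name(self, name: str) -> Optional[str]:
        # Same lookup order as the new KernelArgs.arg_name:
        # in-place buffers first, then outputs, then inputs.
        inplaced = self.inplace_buffers.get(name, None)
        if inplaced is not None and not isinstance(inplaced, RemovedArg):
            return inplaced.inner_name
        output_name = self.output_buffers.get(name, None)
        if output_name is not None and not isinstance(output_name, RemovedArg):
            return output_name
        return self.input_buffers.get(name, None)


args = MiniKernelArgs(
    input_buffers={"buf0": "in_ptr0"},
    output_buffers={"buf1": RemovedArg()},
    inplace_buffers={"buf2": InplacedBuffer("in_out_ptr0")},
)
assert args.arg_name("buf0") == "in_ptr0"
assert args.arg_name("buf2") == "in_out_ptr0"
assert args.arg_name("buf1") is None  # a RemovedArg entry is skipped, not returned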

0 commit comments