 import math
 import operator
 import re
+import typing
 from enum import auto, Enum
 from itertools import chain
 from typing import (
     Any,
     Callable,
     cast,
     ClassVar,
+    Generic,
     Iterator,
+    MutableMapping,
     NamedTuple,
     Optional,
     TYPE_CHECKING,
     Union,
 )
+from typing_extensions import TypeVar

 import sympy

...
     generate_assert,
     IndentedBuffer,
     ir_dataclass,
+    ScopedDict,
     sympy_dot,
     sympy_subs,
     unique,
...

 if TYPE_CHECKING:
-    from typing import Never, TypeVar
-
-    from ..ir import FixedLayout
+    from ..ir import FixedLayout, IRNode
     from ..loop_body import LoopBody
-    from ..scheduler import BaseScheduling, Scheduler
+    from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
     from .wrapper import PythonWrapperCodegen

 _T = TypeVar("_T")
@@ -1341,6 +1344,18 @@ def call_names(self) -> Iterator[str]:
             self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
         )

+    def arg_name(self, name: str) -> Optional[str]:
+        """
+        Returns inner name of a given outer name.
+        """
+        inplaced = self.inplace_buffers.get(name, None)
+        if inplaced is not None and not isinstance(inplaced, RemovedArg):
+            return inplaced.inner_name
+        output_name = self.output_buffers.get(name, None)
+        if output_name is not None and not isinstance(output_name, RemovedArg):
+            return output_name
+        return self.input_buffers.get(name, None)
+
     def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
         return buf

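The `KernelArgs.arg_name` helper added above maps an outer (graph-level) buffer name to the inner (kernel-level) argument name, preferring in-place aliases, then outputs, then plain inputs. A minimal self-contained sketch of that lookup order follows; `InplacedBuffer`, `arg_name_sketch`, and the buffer names are hypothetical stand-ins, not the real Inductor types.

```python
# Sketch only: stand-ins for KernelArgs' internal tables, with made-up names.
from dataclasses import dataclass
from typing import Optional


@dataclass
class InplacedBuffer:
    """Stand-in for the real in-place record; only inner_name is used here."""
    inner_name: str


class RemovedArg:
    """Stand-in marker for arguments that were dropped from the signature."""


def arg_name_sketch(
    inplace_buffers: dict,
    output_buffers: dict,
    input_buffers: dict,
    name: str,
) -> Optional[str]:
    # 1) in-place aliases win: the record carries the inner name explicitly
    inplaced = inplace_buffers.get(name)
    if inplaced is not None and not isinstance(inplaced, RemovedArg):
        return inplaced.inner_name
    # 2) then outputs, which map outer name -> inner name directly
    output = output_buffers.get(name)
    if output is not None and not isinstance(output, RemovedArg):
        return output
    # 3) finally inputs, or None if the name is unknown to the kernel
    return input_buffers.get(name)


# Hypothetical tables: "buf0" reuses an input's storage, "buf1" is a plain output.
assert arg_name_sketch({"buf0": InplacedBuffer("in_out_ptr0")}, {}, {}, "buf0") == "in_out_ptr0"
assert arg_name_sketch({}, {"buf1": "out_ptr0"}, {}, "buf1") == "out_ptr0"
assert arg_name_sketch({}, {}, {"arg0_1": "in_ptr0"}, "arg0_1") == "in_ptr0"
```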
@@ -1482,17 +1497,18 @@ class CSEVariable:

     def __init__(
         self,
-        name,
+        name: str,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
     ):
+        super().__init__()
         assert isinstance(bounds, ValueRanges)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
         self.dtype = dtype

-    def __str__(self):
+    def __str__(self) -> str:
         return self.name

     def __hash__(self) -> int:
@@ -1501,68 +1517,86 @@ def __hash__(self) -> int:
     def __eq__(self, other) -> bool:
         return type(other) == type(self) and other.name == self.name

-    def update_on_args(self, name, args, kwargs):
+    def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
         pass

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.name!r})"


-class CSE:
+AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
+CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)
+
+if TYPE_CHECKING:
+    ReductionCacheKey = tuple[
+        torch.dtype,
+        ReductionType,
+        Union[CSEVariable, tuple[CSEVariable, ...]],
+    ]
+
+
+class CSE(Generic[CSEVariableType, AugmentedKeyT]):
     """Common subexpression elimination"""

     def __init__(
         self,
-        prefix="",
-        suffix="",
-        name_prefix="tmp",
-        iter_buffers=None,
-        store_cache=None,
-        reduction_cache=None,
-        varname_map=None,
+        prefix: str = "",
+        suffix: str = "",
+        name_prefix: str = "tmp",
+        iter_buffers: Optional[itertools.count[int]] = None,
+        store_cache: Optional[MutableMapping[str, CSEVariableType]] = None,
+        reduction_cache: Optional[
+            MutableMapping[ReductionCacheKey, CSEVariableType]
+        ] = None,
+        varname_map: Optional[dict[str, CSEVariableType]] = None,
     ):
         self.prefix = prefix
         self.suffix = suffix
-        self._cache = {}
+        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}
         self.name_prefix = name_prefix
-        self.store_cache = store_cache or {}
-        self.reduction_cache = reduction_cache or {}
-        self.iter_buffer_ids = iter_buffers or itertools.count()
-        self.invalidated_stores = OrderedSet[str]()
-        self.varname_map = varname_map or {}
-
-    def invalidate(self, keep_vars: Union[OrderedSet[str], OrderedSet[Never]]):
-        for name, tmp in list(self.store_cache.items()):
+        self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {}
+        self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = (
+            reduction_cache or {}
+        )
+        self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count()
+        self.invalidated_stores: OrderedSet[str] = OrderedSet()
+        self.varname_map: dict[str, CSEVariableType] = varname_map or {}
+
+    def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
+        for name, tmp in [*self.store_cache.items()]:
             if tmp not in keep_vars:
                 del self.store_cache[name]
                 self.invalidated_stores.add(name)
-        self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
+        if keep_vars:
+            self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
+        else:
+            self._cache = {}

-    def clone(self):
-        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
+    def clone(self) -> typing.Self:
         return type(self)(
             prefix=self.prefix,
             suffix=self.suffix,
             name_prefix=self.name_prefix,
             iter_buffers=self.iter_buffer_ids,
             store_cache=self.store_cache,
             varname_map=self.varname_map,
+            reduction_cache=self.reduction_cache,
         )

-    def augment_key(self, cache_key: object) -> object:
+    def augment_key(self, cache_key: str) -> AugmentedKeyT:
         "Override this method to augment cache key with backend specifics"
-        return cache_key
+        return cast(AugmentedKeyT, cache_key)

-    def put(self, cache_key: object, val: CSEVariable) -> None:
+    def put(self, cache_key: str, val: CSEVariableType) -> None:
         self._cache[self.augment_key(cache_key)] = val

-    def contains(self, cache_key) -> bool:
+    def contains(self, cache_key: str) -> bool:
         return self.augment_key(cache_key) in self._cache

-    def try_get(self, cache_key: object) -> Optional[CSEVariable]:
+    def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
         return self._cache.get(self.augment_key(cache_key), None)

-    def get(self, cache_key: object) -> CSEVariable:
+    def get(self, cache_key: str) -> CSEVariableType:
         return self._cache[self.augment_key(cache_key)]

     def generate(
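`CSE` is now generic over two parameters: `CSEVariableType`, the backend's `CSEVariable` subclass stored in the caches, and `AugmentedKeyT`, whatever `augment_key` turns the raw expression string into. The sketch below shows how a backend could specialize both; the simplified `CSE` mirrors only the caching methods visible in this hunk, and `MyBackendCSE` / `MyBackendVariable` with a dtype-tagged key are hypothetical, not an actual Inductor backend.

```python
# Sketch only: simplified CSE plus a hypothetical backend that augments the
# cache key with a dtype tag. The real Inductor backends differ.
from typing import Generic, MutableMapping, Optional, cast

from typing_extensions import TypeVar


class CSEVariable:
    def __init__(self, name: str) -> None:
        self.name = name


AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)


class CSE(Generic[CSEVariableType, AugmentedKeyT]):
    def __init__(self) -> None:
        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}

    def augment_key(self, cache_key: str) -> AugmentedKeyT:
        # Default behaviour: the raw expression string is the cache key.
        return cast(AugmentedKeyT, cache_key)

    def put(self, cache_key: str, val: CSEVariableType) -> None:
        self._cache[self.augment_key(cache_key)] = val

    def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
        return self._cache.get(self.augment_key(cache_key))


class MyBackendVariable(CSEVariable):
    pass


class MyBackendCSE(CSE[MyBackendVariable, tuple[str, str]]):
    """Keys the cache on (expression, dtype) so the same expression text at
    two dtypes maps to two distinct variables."""

    def __init__(self, dtype: str) -> None:
        super().__init__()
        self.dtype = dtype

    def augment_key(self, cache_key: str) -> tuple[str, str]:
        return (cache_key, self.dtype)


cse = MyBackendCSE("float32")
cse.put("tmp0 + tmp1", MyBackendVariable("tmp2"))
assert cse.try_get("tmp0 + tmp1") is not None  # hit under float32
cse.dtype = "float16"
assert cse.try_get("tmp0 + tmp1") is None      # same text, different augmented key
```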
@@ -1571,10 +1605,10 @@ def generate(
         expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase],
         *,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
-        write=True,
-        assignment=True,
+        write: bool = True,
+        assignment: bool = True,
         dtype: Optional[torch.dtype] = None,
-    ) -> CSEVariable:
+    ) -> CSEVariableType:
         if isinstance(expr, OpsValue):
             expr = expr.value

@@ -1585,7 +1619,7 @@ def generate(
             # with the loose ValueRanges.unknown(), so we need to tighten the bounds
             expr.bounds = expr.bounds.tighten(bounds)
             expr.use_count += 1
-            return expr
+            return cast(CSEVariableType, expr)
         elif isinstance(expr, IndentedBuffer):
             cache_key = expr.getvalue()
         elif isinstance(expr, DeferredLineBase):
@@ -1628,7 +1662,7 @@ def newvar(
         self,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
-    ) -> CSEVariable:
+    ) -> CSEVariableType:
         var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
         var = V.kernel.create_cse_var(var_name, bounds, dtype)
         self.varname_map[var_name] = var
@@ -1639,7 +1673,7 @@ def namedvar(
         name: str,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
-    ) -> CSEVariable:
+    ) -> CSEVariableType:
         torch._check_value(
             name not in self.varname_map, lambda: f"duplicate name: {name}"
         )
@@ -1653,45 +1687,22 @@ def __init__(self) -> None:
         super().__init__()
         self.exit_stack = contextlib.ExitStack()

-    def __enter__(self):
+    def __enter__(self) -> typing.Self:
         self.exit_stack.__enter__()
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         self.exit_stack.__exit__(exc_type, exc_val, exc_tb)


-class ScopedDict:
-    def __init__(self, original_dict):
-        self.original_dict = original_dict
-        self.new_items = {}
-
-    def __getitem__(self, key):
-        if key in self.new_items:
-            return self.new_items[key]
-        return self.original_dict[key]
-
-    def __setitem__(self, key, value):
-        self.new_items[key] = value
-
-    def __contains__(self, key):
-        return key in self.new_items or key in self.original_dict
-
-    def get(self, key, default=None):
-        if key in self.new_items:
-            return self.new_items[key]
-        return self.original_dict.get(key, default)
-
-
-class Kernel(CodeGen):
-    newvar_prefix = ""
-    suffix = ""
+class Kernel(CodeGen, Generic[CSEVariableType]):
+    newvar_prefix: str = ""
+    suffix: str = ""
     overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
-    # TODO: these look dead, but with all the getattr it's hard to tell...
-    load_format: None = None
-    store_format: None = None

-    def __init__(self, args=None, increase_kernel_count=True):
+    def __init__(
+        self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
+    ) -> None:
         super().__init__()
         if increase_kernel_count:
             metrics.generated_kernel_count += 1
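The `ScopedDict` definition removed here is the same helper now imported near the top of the diff (its new home module is not visible in this excerpt). Its behaviour is an overlay: reads fall through to the wrapped dict, writes stay in a private layer, which is what the `swap_buffers`/`scope_cse` hunk further down relies on to scope `_cache` and `reduction_cache` without mutating the parent kernel's state. A short usage sketch of that overlay semantics, using a local stand-in class:

```python
class ScopedDictSketch:
    """Local stand-in mirroring the overlay semantics of the removed class."""

    def __init__(self, original_dict):
        self.original_dict = original_dict  # parent scope, never written to
        self.new_items = {}                 # all writes land in this layer

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)


parent_cache = {"x + y": "tmp0"}      # e.g. a parent kernel's CSE cache
scoped = ScopedDictSketch(parent_cache)
assert scoped["x + y"] == "tmp0"      # reads fall through to the parent
scoped["x * y"] = "tmp1"              # new entries stay in the scope
assert "x * y" in scoped
assert "x * y" not in parent_cache    # the parent cache is untouched
```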
@@ -1703,13 +1714,13 @@ def __init__(self, args=None, increase_kernel_count=True):
         self.num_load = 0
         self.num_reduction = 0

-        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
+        self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
         self.must_keep_buffers = OrderedSet[str]()
         self.store_buffer_names = OrderedSet[str]()
-        self._load_mask = None
-        self._load_other = None
+        self._load_mask: Optional[str] = None
+        self._load_other: Union[None, int, float] = None
         # OrderedSet in set_current_node
-        self.current_node = None
+        self.current_node: Optional[SchedulerNode] = None
         self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None

         self.removed_buffers = OrderedSet[str]()
@@ -1718,10 +1729,10 @@ def __init__(self, args=None, increase_kernel_count=True):
         # key: the buffer to write
         # value: the buffer to read and whose memory can be reused for
         # the buffer specified by key
-        self.inplace_update_buffers = {}
+        self.inplace_update_buffers: dict[str, str] = {}
         # Set minimum number of elements processed per thread.
         self.min_elem_per_thread = 1
-        self.kernel_name = None
+        self.kernel_name: Optional[str] = None

     @contextlib.contextmanager
     def set_current_node(self, node):
@@ -1735,7 +1746,7 @@ def set_current_node(self, node):

     @contextlib.contextmanager
     def swap_buffers(self, lb, cb=None, sb=None):
-        def scope_cse(cse):
+        def scope_cse(cse: CSE[CSEVariableType, Any]):
             new_cse = cse.clone()
             new_cse._cache = ScopedDict(cse._cache)
             new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
@@ -2062,6 +2073,7 @@ def load(name: str, index: sympy.Expr) -> CSEVariable:

             @staticmethod
             def _update_store_cache(name: str, value: CSEVariable):
+                value = cast(CSEVariableType, value)
                 self.cse.store_cache[name] = value
                 if self.current_node and name in V.graph.name_to_buffer:
                     buf = self.current_node.get_output(name)
@@ -2288,6 +2300,14 @@ def rename_indexing(self, index) -> sympy.Expr:
     def create_cse_var(self, *args, **kwargs):
         return CSEVariable(*args, **kwargs)

+    def arg_name(self, node: IRNode) -> Optional[str]:
+        """
+        Returns arg name of a given input or output node.
+        """
+        if node is None:
+            return None
+        return self.args.arg_name(node.get_name())
+

 @dataclasses.dataclass
 class OptimizationContext:
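The companion `Kernel.arg_name` added at the end of the diff resolves an `IRNode` to its kernel argument name by delegating to `KernelArgs.arg_name` (sketched earlier) via `node.get_name()`, and returns `None` for a missing node. A tiny hypothetical illustration of that delegation; `FakeNode` and `FakeKernel` are stand-ins, not Inductor classes:

```python
from typing import Optional


class FakeNode:
    """Stand-in for an IRNode exposing just get_name()."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name


class FakeKernel:
    """Stand-in for Kernel holding a name -> inner-name table in self.args."""

    class _Args:
        table = {"buf3": "out_ptr0"}  # made-up mapping

        def arg_name(self, name: str) -> Optional[str]:
            return self.table.get(name)

    def __init__(self) -> None:
        self.args = self._Args()

    def arg_name(self, node) -> Optional[str]:
        if node is None:
            return None
        return self.args.arg_name(node.get_name())


assert FakeKernel().arg_name(FakeNode("buf3")) == "out_ptr0"
assert FakeKernel().arg_name(FakeNode("buf9")) is None  # unknown buffer
assert FakeKernel().arg_name(None) is None
```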