[inductor] propagate shapes in CSEVariable · pytorch/pytorch@ecc11f5 · GitHub

Commit ecc11f5

[inductor] propagate shapes in CSEVariable
ghstack-source-id: 540faf4
Pull Request resolved: #152198
1 parent 67f7524 commit ecc11f5
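At a glance, this commit threads an optional shape through CSE variable creation in the inductor codegen backends. The new `ShapeType` alias added to common.py below admits both concrete and symbolic sizes; a small illustration of the kinds of values it is meant to hold (the example values themselves are made up):

from typing import Optional, Sequence, Union

ShapeType = Optional[Sequence[Union[int, str]]]

# A shape may mix concrete sizes with symbolic dims rendered as strings,
# or be None when nothing is known about the value being CSE'd.
fully_known: ShapeType = (128, 64)
partially_symbolic: ShapeType = (128, "s0")
unknown: ShapeType = None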

File tree

8 files changed: +341, -84 lines

torch/_inductor/codegen/common.py

Lines changed: 19 additions & 5 deletions
@@ -74,6 +74,8 @@
 # causes typing errors in subclasses (defined in other files).
 OpVarT = str
 
+ShapeType = Optional[Sequence[Union[int, str]]]
+
 schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
 log = logging.getLogger(__name__)
 
@@ -1645,13 +1647,15 @@ def __init__(
         name: str,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
+        shape: Optional[ShapeType] = None,
     ):
         super().__init__()
         assert isinstance(bounds, ValueRanges)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
         self.dtype = dtype
+        self.shape = shape
 
     def __str__(self) -> str:
         return self.name
@@ -1761,6 +1765,7 @@ def generate(
         write: bool = True,
         assignment: bool = True,
         dtype: Optional[torch.dtype] = None,
+        shape: Optional[ShapeType] = None,
     ) -> CSEVariableType:
         if isinstance(expr, OpsValue):
             expr = expr.value
@@ -1782,7 +1787,7 @@ def generate(
             cache_key = expr
         var = self.try_get(cache_key)
         if not var:
-            var = self.newvar(bounds, dtype)
+            var = self.newvar(bounds, dtype, shape)
             self.put(cache_key, var)
             if write:
                 if V.kernel.current_node:
@@ -1828,9 +1833,10 @@ def newvar(
         self,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
+        shape: Optional[ShapeType] = None,
     ) -> CSEVariableType:
         var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
-        var = V.kernel.create_cse_var(var_name, bounds, dtype)
+        var = V.kernel.create_cse_var(var_name, bounds, dtype, shape)
         self.varname_map[var_name] = var
         return var
 
@@ -1839,11 +1845,12 @@ def namedvar(
         name: str,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
+        shape: Optional[ShapeType] = None,
     ) -> CSEVariableType:
         torch._check_value(
             name not in self.varname_map, lambda: f"duplicate name: {name}"
         )
-        var = V.kernel.create_cse_var(name, bounds, dtype)
+        var = V.kernel.create_cse_var(name, bounds, dtype, shape)
         self.varname_map[name] = var
         return var
 
@@ -2319,7 +2326,7 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) ->
 
         output_idx = 0
 
-        def do_cse(v: str) -> CSEVariable:
+        def do_cse(v: Union[str, CSEVariable]) -> CSEVariable:
             # we tree_map over the output, so we need to fetch corresponding dtype
             nonlocal output_idx
             var_dtype: torch.dtype = (
@@ -2338,6 +2345,7 @@ def do_cse(v: str) -> CSEVariable:
                 v,
                 bounds=bounds,
                 dtype=output_dtype,
+                shape=getattr(v, "shape", None),
             )
 
             csevar.update_on_args(name, args, kwargs)
@@ -2427,7 +2435,13 @@ def indirect_indexing(
             pos = var.bounds & ValueRanges(0, int_oo)
             new_bounds = new_bounds | pos
 
-        var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds)
+        var = self.kernel.cse.generate(
+            self.kernel.compute,
+            stm,
+            bounds=new_bounds,
+            dtype=var.dtype,
+            shape=var.shape,
+        )
 
         sympy_var = self.parent_handler.indirect_indexing(var, size, check)
         if generate_assert(check):
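Taken together, the common.py changes mean every path that materializes a CSE variable (generate, newvar, namedvar) can now carry a shape, and do_cse forwards the shape of the incoming operand when it has one. A minimal standalone sketch of that flow, using simplified stand-ins (DemoCSEVariable, DemoCSE, and the example shapes are illustrative, not the real inductor classes):

from typing import Optional, Sequence, Union

ShapeType = Optional[Sequence[Union[int, str]]]


class DemoCSEVariable:
    """Stand-in for CSEVariable: a named value that remembers dtype and shape."""

    def __init__(self, name: str, dtype=None, shape: ShapeType = None):
        self.name = name
        self.dtype = dtype
        self.shape = shape

    def __repr__(self):
        return f"{self.name}(dtype={self.dtype}, shape={self.shape})"


class DemoCSE:
    """Stand-in for the CSE cache: dedups expressions and threads shape into new vars."""

    def __init__(self):
        self.cache = {}
        self.counter = 0

    def newvar(self, dtype=None, shape: ShapeType = None) -> DemoCSEVariable:
        self.counter += 1
        return DemoCSEVariable(f"tmp{self.counter}", dtype, shape)

    def generate(self, expr: str, dtype=None, shape: ShapeType = None) -> DemoCSEVariable:
        if expr not in self.cache:
            self.cache[expr] = self.newvar(dtype, shape)
        return self.cache[expr]


def do_cse(cse: DemoCSE, v, dtype=None) -> DemoCSEVariable:
    # Mirrors the diff: take the shape off the operand when it already carries one.
    return cse.generate(str(v), dtype=dtype, shape=getattr(v, "shape", None))


cse = DemoCSE()
a = cse.newvar(dtype="float32", shape=(8, "s0"))  # symbolic dim kept as a string
b = do_cse(cse, a, dtype="float32")               # shape propagates from `a` to `b`
assert b.shape == (8, "s0")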

torch/_inductor/codegen/cpp.py

Lines changed: 2 additions & 2 deletions
@@ -902,8 +902,8 @@ def frexp(x):
             return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys)
 
         code = BracesBuffer()
-        exponent = V.kernel.cse.newvar(dtype=torch.int32)
-        mantissa = V.kernel.cse.newvar(dtype=x.dtype)
+        exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape)
+        mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape)
         code.writeline(f"int32_t {exponent};")
         code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});")
         V.kernel.compute.splice(code)
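The frexp handler produces two outputs (mantissa and exponent) from a single input; both are elementwise over x, so both newly created CSE variables now inherit x.shape rather than only a dtype. A small hedged illustration of that elementwise contract using Python's math.frexp (the input list is made up for the example):

import math

x = [0.5, 8.0, 12.0]
mantissa, exponent = zip(*(math.frexp(v) for v in x))
# Both outputs have exactly one element per input element, i.e. the same shape as x.
assert len(mantissa) == len(exponent) == len(x)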

torch/_inductor/codegen/cpp_utils.py

Lines changed: 3 additions & 1 deletion
@@ -22,6 +22,7 @@
 from ..dependencies import Dep
 from ..loop_body import LoopBody
 from ..scheduler import BaseSchedulerNode, SchedulerBuffer
+from ..shape_propagation import ShapeType
 from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
 from ..virtualized import ops, OpsValue, V
 from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext
@@ -144,8 +145,9 @@ def __init__(
         name,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
+        shape: Optional[ShapeType] = None,
     ) -> None:
-        super().__init__(name, bounds, dtype)
+        super().__init__(name, bounds, dtype, shape=shape)
         self.is_vec = False
         self.dependent_itervars = OrderedSet[sympy.Symbol]()
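Here the C++ backend's CSE variable simply forwards the new argument to the base class before adding its own backend state. A minimal sketch of that forwarding pattern, with simplified stand-ins (BaseVar and CppVar are illustrative, not the real classes):

from typing import Optional, Sequence, Union

ShapeType = Optional[Sequence[Union[int, str]]]


class BaseVar:
    def __init__(self, name: str, dtype=None, shape: ShapeType = None):
        self.name, self.dtype, self.shape = name, dtype, shape


class CppVar(BaseVar):
    def __init__(self, name: str, dtype=None, shape: ShapeType = None):
        # Pass shape through unchanged, then add backend-specific state on top.
        super().__init__(name, dtype, shape=shape)
        self.is_vec = False


v = CppVar("tmp0", dtype="float32", shape=(4, 4))
assert v.shape == (4, 4) and v.is_vec is False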

torch/_inductor/codegen/halide.py

Lines changed: 18 additions & 7 deletions
@@ -54,6 +54,7 @@
 from collections.abc import Sequence
 
 from ..ops_handler import ReductionType, StoreMode
+from ..shape_propagation import ShapeType
 
 log = logging.getLogger(__name__)
 
@@ -556,6 +557,7 @@ def masked(mask, body, other):
             f"hl.cast({result.name}.type(), {halide_constant(other)})",
             [],
             bounds=ValueRanges.wrap(other),
+            shape=result.shape,
         )
     # TODO(jansel): look into removing the where in the same places triton does
     return ops.where(new_mask, result, other)
@@ -576,8 +578,9 @@ def __init__(
         name,
         bounds: ValueRanges[Any],
         dtype: Optional[torch.dtype] = None,
+        shape: Optional[ShapeType] = None,
     ) -> None:
-        super().__init__(name, bounds, dtype)
+        super().__init__(name, bounds, dtype, shape=shape)
         self.used_dims: Optional[list[sympy.Symbol]] = None
 
     def update_on_args(self, name, args, kwargs):
@@ -1196,12 +1199,13 @@ def reduction(
         assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
         reduction_vars = OrderedSet(self.reduction_renames)
         result_var = self.newfunc(
-            [v for v in value.used_dims if v not in reduction_vars]
+            [v for v in value.used_dims if v not in reduction_vars],
         )
         if reduction_vars - OrderedSet(value.used_dims):
             value = self.genfunc(
                 f"{value}",
                 self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))),
+                shape=value.shape,
             )
         value_str = value.subs_str(self.reduction_renames)
         default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
@@ -1291,7 +1295,9 @@ def scan(
         else:
             values.append(
                 self.genfunc(
-                    f"{value}", [*value.used_dims, [*self.reduction_renames][:1]]
+                    f"{value}",
+                    [*value.used_dims, [*self.reduction_renames][:1]],
+                    shape=value.shape,
                 )
             )
         all_used_dims.update(value.used_dims)
@@ -1355,15 +1361,20 @@ def maybe_tuple(x):
         return tuple(unpack_vars)
 
     def genfunc(
-        self, line, used_dims, *, bounds=ValueRanges.unknown()
+        self,
+        line,
+        used_dims,
+        *,
+        bounds=ValueRanges.unknown(),
+        shape=None,
     ) -> HalideCSEVariable:
-        var = self.cse.generate(self.body, line, bounds=bounds)
+        var = self.cse.generate(self.body, line, bounds=bounds, shape=shape)
         assert isinstance(var, HalideCSEVariable)
        var.used_dims = used_dims
         return var
 
-    def newfunc(self, used_dims) -> HalideCSEVariable:
-        var = self.cse.newvar()
+    def newfunc(self, used_dims, *, shape=None) -> HalideCSEVariable:
+        var = self.cse.newvar(shape=shape)
         assert isinstance(var, HalideCSEVariable)
         var.used_dims = used_dims
         return var
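In the Halide backend, genfunc and newfunc gain shape as a keyword-only parameter defaulting to None, so call sites that do not (or cannot) supply a shape remain valid while the reduction and scan paths forward value.shape. A small sketch of that keyword-only default, with a hypothetical genfunc stand-in (not the real method):

def genfunc(line, used_dims, *, bounds=None, shape=None):
    # Hypothetical stand-in: returns what the real helper would attach to the CSE var.
    return {"line": line, "used_dims": used_dims, "bounds": bounds, "shape": shape}


old_style = genfunc("v0 + v1", ["x"])               # no shape: still a valid call
new_style = genfunc("v0 + v1", ["x"], shape=(16,))  # shape threaded through explicitly
assert old_style["shape"] is None and new_style["shape"] == (16,)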

0 commit comments
