Commit 77428b8
[aoti] fix corner case in unbacked replacements for atomically_apply_size_hint
ghstack-source-id: f4396c4 Pull Request resolved: #153768
1 parent 8ac82a1 commit 77428b8
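
In short, the fix makes the unbacked replacement map be built once from all deferred runtime asserts (rather than per expression) and oriented consistently, so equalities recorded via torch._check can be chained when Inductor picks example sizes for autotuning. A rough standalone sketch of that orientation step, using placeholder sympy symbols rather than a real ShapeEnv:

    # Illustrative only: how an equality learned from torch._check(a == b) becomes
    # a single, consistently oriented rewrite rule (mirrors the diff below).
    import sympy

    s0, u0, u1 = sympy.symbols("s0 u0 u1", integer=True, positive=True)

    equalities = [sympy.Eq(s0 + u0, u1)]  # e.g. "s0 + u0 is known to equal u1"

    replacements = {}
    for eq in equalities:
        lhs, rhs = eq.lhs, eq.rhs
        # sympy.Basic.compare gives a deterministic ordering: the side that sorts
        # later is always rewritten in terms of the side that sorts earlier, so
        # both orderings of the same equality produce the same rule.
        l2r = lhs.compare(rhs) == 1
        src, dst = (lhs, rhs) if l2r else (rhs, lhs)
        replacements[src] = dst

    print(replacements)  # one rewrite rule; both expressions get the same size hint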

File tree

2 files changed: +144 -29 lines changed

test/inductor/test_aot_inductor.py

Lines changed: 110 additions & 7 deletions
@@ -1299,13 +1299,7 @@ def forward(self, values, repeats, mask, embeddings, x, z, scalar):
 
                 unbacked_add_expr = backed + unbacked
                 repeated = x.repeat(unbacked_add_expr, 1)
-                return torch.cat(
-                    [
-                        repeated,
-                        index_select,
-                    ],
-                    dim=1,
-                )
+                return torch.cat([repeated, index_select], dim=1)
 
         example_inputs = (
             torch.ones(64, dtype=torch.int64, device=self.device),
@@ -1327,6 +1321,115 @@ def forward(self, values, repeats, mask, embeddings, x, z, scalar):
         }
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
 
+    def test_size_with_unbacked_add_expr_transitive(self):
+        # Edge case with torch._check(expr1, expr2) + torch._check(expr2, unbacked).
+        # When generating example input sizes for autotuning, it should coalesce
+        # expr1, expr2, unbacked into a single size.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Repro(torch.nn.Module):
+            def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
+                u0, u1, random_unbacked = lst.tolist()
+                torch._check_is_size(u0)
+                torch._check_is_size(u1)
+                backed = z.size(0)
+                backed1 = z.size(1)
+
+                repeated = x.repeat(backed + u0, 1)
+                repeated1 = y.repeat(backed1 + u1, 1)
+                out = torch.empty_like(repeated)
+                add_kernel[(out.numel(),)](
+                    repeated, repeated, out, out.numel(), BLOCK_SIZE=2
+                )
+
+                torch._check(repeated1.size(0) == out.size(0))
+                torch._check(out.size(0) == random_unbacked)
+
+                index = torch.repeat_interleave(values, repeats)
+                index_select = torch.index_select(embeddings, 0, index)
+
+                cat = torch.cat([out, index_select], dim=1)
+                add = repeated + repeated1
+                return cat, add
+
+        example_inputs = (
+            torch.ones(64, dtype=torch.int64, device=self.device),
+            torch.ones(64, dtype=torch.int64, device=self.device) * 24,
+            torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
+            torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.ones(758, 758, dtype=torch.int64, device=self.device),
+            torch.tensor(
+                [10, 10, 2 * (758 + 10)], dtype=torch.int32, device=self.device
+            ),
+        )
+        spec = {
+            "values": (Dim.DYNAMIC,),
+            "repeats": (Dim.DYNAMIC,),
+            "mask": (Dim.DYNAMIC,),
+            "embeddings": (Dim.DYNAMIC, Dim.STATIC),
+            "x": (Dim.DYNAMIC, Dim.STATIC),
+            "y": (Dim.DYNAMIC, Dim.STATIC),
+            "z": (Dim.DYNAMIC, Dim.DYNAMIC),
+            "lst": (Dim.STATIC,),
+        }
+        self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
+
+    @config.patch({"unbacked_symint_fallback": 1024})
+    def test_size_with_unbacked_add_and_mul_expr(self):
+        # Edge case with torch._check(add_expr, mul_expr). When generating example
+        # input sizes for autotuning, make sure they coalesce into a single size.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Repro(torch.nn.Module):
+            def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
+                u0, u1, u2 = lst.tolist()
+                torch._check_is_size(u0)
+                torch._check_is_size(u1)
+                torch._check_is_size(u2)
+                backed = z.size(0)
+                backed1 = z.size(1)
+
+                unbacked_add_expr = backed + u0
+                unbacked_mul_expr = backed1 + (u1 * u2)
+                repeated0 = x.repeat(unbacked_add_expr, 1)
+                repeated1 = y.repeat(unbacked_mul_expr, 1)
+                out0 = torch.empty_like(repeated0)
+                out1 = torch.empty_like(repeated1)
+                add_kernel[(out0.numel(),)](
+                    repeated0, repeated0, out0, out0.numel(), BLOCK_SIZE=2
+                )
+                add_kernel[(out1.numel(),)](
+                    repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
+                )
+
+                return torch.cat([out1, out0], dim=1)
+
+        example_inputs = (
+            torch.ones(64, dtype=torch.int64, device=self.device),
+            torch.ones(64, dtype=torch.int64, device=self.device) * 24,
+            torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
+            torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.ones(758, 758, dtype=torch.int64, device=self.device),
+            torch.tensor([10, 5, 2], dtype=torch.int32, device=self.device),
+        )
+        spec = {
+            "values": (Dim.DYNAMIC,),
+            "repeats": (Dim.DYNAMIC,),
+            "mask": (Dim.DYNAMIC,),
+            "embeddings": (Dim.DYNAMIC, Dim.STATIC),
+            "x": (Dim.DYNAMIC, Dim.STATIC),
+            "y": (Dim.DYNAMIC, Dim.STATIC),
+            "z": (Dim.DYNAMIC, Dim.DYNAMIC),
+            "lst": (Dim.STATIC,),
+        }
+        self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
+
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
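
Both new tests drive a user-defined Triton kernel through AOTInductor: add_kernel is imported from PyTorch's shared Triton test utilities rather than defined in this file. For readers of the diff, a minimal sketch of such a pointwise add kernel, assuming the standard Triton API; the real helper may differ in details:

    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Each program instance handles one BLOCK_SIZE-wide slice; the mask guards
        # the tail when n_elements is not a multiple of BLOCK_SIZE.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    # Launched as in the tests above:
    # add_kernel[(out.numel(),)](a, b, out, out.numel(), BLOCK_SIZE=2)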

torch/_inductor/sizevars.py

Lines changed: 34 additions & 22 deletions
@@ -8,7 +8,11 @@
 import sympy
 from sympy import Expr
 
-from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
+from torch.fx.experimental.symbolic_shapes import (
+    free_unbacked_symbols,
+    has_free_unbacked_symbols,
+    ShapeEnv,
+)
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import symbol_is_type, SymT
@@ -62,7 +66,7 @@ def __init__(self, shape_env=None) -> None:
         self.shape_env = shape_env
         self.var_to_val = self.shape_env.var_to_val
         self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
-        self.unbacked_replacements: dict[Expr, Expr] = {}
+        self.unbacked_replacements: Optional[dict[Expr, Expr]] = None
         # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
         # The basic idea is if we have some complicated sympy expression
         # f(s0), we may choose to precompute it on the host and then replace
@@ -639,7 +643,7 @@ def _stride_vars(
         )
         return strides
 
-    def _get_unbacked_replacements(self, expr: Expr) -> dict[Expr, Expr]:
+    def _get_unbacked_replacements(self) -> dict[Expr, Expr]:
         """
         This helps with covering unbacked symint cases where you may have two
         expressions: s0 + u0 and u1. And s0 + u0 is known to be equal to u1
@@ -649,33 +653,41 @@ def _get_unbacked_replacements(self, expr: Expr) -> dict[Expr, Expr]:
         hint for both s0 + u0 and u1, but it first needs to know they are equal.
         Then it can substitute s0 + u0 for u1.
         """
-        if expr in self.unbacked_replacements:
-            return self.unbacked_replacements[expr]
+        if self.unbacked_replacements is not None:
+            return self.unbacked_replacements
 
-        runtime_asserts = itertools.chain.from_iterable(
-            self.shape_env.deferred_runtime_asserts.get(u, [])
-            for u in free_unbacked_symbols(expr)
-        )
-        equalities = (
-            assertion.expr
-            for assertion in runtime_asserts
-            if isinstance(assertion.expr, sympy.Equality)
-        )
-        replacements = {eq.rhs: eq.lhs for eq in equalities}
+        self.unbacked_replacements = {}
+        for assertions in self.shape_env.deferred_runtime_asserts.values():
+            for assertion in assertions:
+                if not isinstance(assertion.expr, sympy.Equality):
+                    continue
 
-        self.unbacked_replacements[expr] = replacements
-        return replacements
+                lhs, rhs = assertion.expr.lhs, assertion.expr.rhs
+                l2r = lhs.compare(rhs) == 1  # see sympy.Basic.compare
+                src = lhs if l2r else rhs
+                dst = rhs if l2r else lhs
+                self.unbacked_replacements[src] = dst
+        return self.unbacked_replacements
 
     def atomically_apply_size_hint(
         self, expr: Union[Expr, int], *, fallback: Optional[int] = None
     ) -> Union[Expr, int]:
-        if isinstance(expr, int):
+        if isinstance(expr, (int, sympy.Integer)):
             return int(expr)
 
-        # Make sure to substitute with the factored version
-        # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0
-        unbacked_replacements = self._get_unbacked_replacements(expr)
-        expr = sympy.factor(expr).subs(unbacked_replacements)
+        if has_free_unbacked_symbols(expr):
+
+            def _sub_unbacked_exprs(expr: Expr) -> Expr:
+                replacements = self._get_unbacked_replacements()
+                while True:
+                    new_expr = expr.subs(replacements)
+                    if new_expr == expr:
+                        return new_expr
+                    expr = sympy.factor(new_expr)
+
+            # Make sure to substitute with the factored version
+            # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0
+            expr = _sub_unbacked_exprs(sympy.factor(expr))
 
         # For multiple expressions that depend on an unbacked symint,
         # we want to compute them consistently for a size hint we have chosen.
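
The reason substitution now runs to a fixed point: with transitive torch._check equalities (the new test asserts repeated1.size(0) == out.size(0) and out.size(0) == random_unbacked), a single pass of subs may leave an expression that can still be rewritten. A hedged standalone illustration of the loop in _sub_unbacked_exprs, with made-up symbols and an assumed orientation of the replacement map:

    # Illustration only: chained rewrite rules can require several subs/factor
    # passes before the expression stops changing.
    import sympy

    s0, s1, u0, u1, u2 = sympy.symbols("s0 s1 u0 u1 u2", integer=True, positive=True)

    # Two equalities as they might be oriented after canonicalization:
    #   s1 + u1 == s0 + u0   and   u2 == s1 + u1
    replacements = {s1 + u1: s0 + u0, u2: s1 + u1}

    expr = sympy.factor(10 * u2)  # a size expression seen while autotuning
    while True:
        new_expr = expr.subs(replacements)
        if new_expr == expr:  # fixed point: nothing left to rewrite
            break
        expr = sympy.factor(new_expr)

    print(expr)  # 10*(s0 + u0) -- every size known equal to u2 shares this hint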
