8000 [dynamic shapes] bound_sympy for size-oblivious min/max reasoning (#1… · pytorch/pytorch@13339ce · GitHub
[go: up one dir, main page]

Skip to content

Commit 13339ce

Browse files
pianpwkpytorchmergebot
authored andcommitted
[dynamic shapes] bound_sympy for size-oblivious min/max reasoning (#151242)
Differential Revision: D72978020 Pull Request resolved: #151242 Approved by: https://github.com/bobrenjc93
1 parent 74074fe commit 13339ce

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

test/dynamo/test_misc.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10876,13 +10876,18 @@ def cf(x):
1087610876
torch._check_is_size(u0)
1087710877
torch._check_is_size(u1)
1087810878
torch._check(u0 + u1 == 20)
10879+
10880+
y = 0
1087910881
if guard_size_oblivious(torch.sym_max(1, u0 + u1) == 20):
10880-
return torch.tensor(True)
10881-
else:
10882-
return torch.tensor(False)
10882+
y += 1
10883+
if guard_size_oblivious(torch.sym_max(1, u0**2 + u1 + 2) != 1):
10884+
y += 1
10885+
if guard_size_oblivious(torch.sym_min(1, u0) == 1):
10886+
y += 1
10887+
return y
1088310888

1088410889
# Previously would have thrown guard on data dependent
10885-
cf(torch.tensor([10, 10])).item()
10890+
self.assertEqual(cf(torch.tensor([10, 10])), 3)
1088610891

1088710892
@torch._dynamo.config.patch(capture_scalar_outputs=True)
1088810893
def test_guard_size_oblivious(self):

torch/fx/experimental/symbolic_shapes.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import sympy
4-
from sympy import Add, S
4+
from sympy import S
55

66

77
"""
@@ -76,6 +76,7 @@
7676
FloorToInt,
7777
IsNonOverlappingAndDenseIndicator,
7878
Max,
79+
Min,
7980
Mod,
8081
PythonMod,
8182
)
@@ -5934,24 +5935,22 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT:
59345935
expr = safe_expand(expr)
59355936
expr = self.replace(expr)
59365937

5937-
if size_oblivious and expr.has(Max):
5938-
max_replacements = {}
5939-
for atom in expr.atoms(Max):
5938+
if size_oblivious and (expr.has(Max) or expr.has(Min)): # type: ignore[has-type]
5939+
min_max_replacements = {}
5940+
for atom in (*expr.atoms(Max), *expr.atoms(Min)): # type: ignore[has-type]
59405941
if len(atom.args) > 2:
59415942
continue
59425943
a, b = atom.args
59435944
if b == 1 or b == 0:
59445945
a, b = b, a
59455946
if a == 1 or a == 0:
5946-
if (
5947-
isinstance(b, Add)
5948-
and len(b.free_symbols) == 2 # TODO: expand to N?
5949-
and b.free_symbols == set(b.atoms())
5950-
and all(x in self.size_like for x in b.free_symbols)
5951-
):
5952-
max_replacements[atom] = b
5953-
if max_replacements:
5954-
expr = expr.xreplace(max_replacements)
5947+
vr = self.bound_sympy(b, size_oblivious=True)
5948+
if vr.lower >= a:
5949+
min_max_replacements[atom] = b if atom.func is Max else a
5950+
elif vr.upper <= a:
5951+
min_max_replacements[atom] = a if atom.func is Max else b
5952+
if min_max_replacements:
5953+
expr = expr.xreplace(min_max_replacements)
59555954
expr = safe_expand(expr)
59565955

59575956
# TODO it would seem that this pass is not necessary given the

0 commit comments

Comments
 (0)
0