8000 Pin `sympy >= 1.13.0` (#130895) · XuehaiPan/pytorch@97c046d · GitHub
[go: up one dir, main page]

Skip to content

Commit 97c046d

Browse files
committed
Pin sympy >= 1.13.0 (pytorch#130895)
------ The opposite of pytorch#130836. Pin `sympy >= 1.13.0` for Python >= 3.9 and `sympy == 1.12.1` for Python 3.8. - pytorch#130836 See the PR description of pytorch#130836 for more details. `sympy` 1.13.0 introduces some breaking changes which break our tests. More specifically: - Ref [Backwards compatibility breaks and deprecations](https://github.com/sympy/sympy/wiki/release-notes-for-1.13.0#backwards-compatibility-breaks-and-deprecations) > BREAKING CHANGE: Float and Integer/Rational no longer compare equal with a == b. From now on Float(2.0) != Integer(2). Previously expressions involving Float would compare unequal e.g. x*2.0 != x*2 but an individual Float would compare equal to an Integer. In SymPy 1.7 a Float will always compare unequal to an Integer even if they have the same "value". Use sympy.numbers.int_valued(number) to test if a number is a concrete number with no decimal part. ([pytorch#25614](sympy/sympy#25614) by [@smichr](https://github.com/smichr)) `sympy >= 1.13.0` is required to enable Python 3.13 support. This should be part of pytorch#130689. - pytorch#130689 Pull Request resolved: pytorch#130895 Approved by: https://github.com/ezyang
1 parent d990dad commit 97c046d

File tree

8 files changed

+77
-40
lines changed

8 files changed

+77
-40
lines changed

.circleci/scripts/binary_linux_test.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,18 @@ if [[ "\$python_nodot" = *310* ]]; then
4646
PROTOBUF_PACKAGE="protobuf>=3.19.0"
4747
fi
4848
49-
if [[ "\$python_nodot" = *39* ]]; then
49+
if [[ "\$python_nodot" = *39* ]]; then
5050
# There's an issue with conda channel priority where it'll randomly pick 1.19 over 1.20
5151
# we set a lower boundary here just to be safe
5252
NUMPY_PIN=">=1.20"
5353
fi
5454
55-
55+
if [[ "\$python_nodot" = *38* ]]; then
56+
# sympy 1.12.1 is the last version that supports Python 3.8
57+
SYMPY_PIN="==1.12.1"
58+
else
59+
SYMPY_PIN=">=1.13.0"
60+
fi
5661
5762
# Move debug wheels out of the package dir so they don't get installed
5863
mkdir -p /tmp/debug_final_pkgs
@@ -83,7 +88,7 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then
8388
"numpy\${NUMPY_PIN}" \
8489
mkl>=2018 \
8590
ninja \
86-
sympy \
91+
"sympy\${SYMPY_PIN}" \
8792
typing-extensions \
8893
${PROTOBUF_PACKAGE}
8994
if [[ "$DESIRED_CUDA" == 'cpu' ]]; then

.github/requirements/pip-requirements-macOS.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ pytest-xdist==3.3.1
1717
pytest-rerunfailures==10.3
1818
pytest-flakefinder==1.1.0
1919
scipy==1.10.1
20-
sympy==1.11.1
20+
sympy==1.12.1 ; python_version == "3.8"
21+
sympy>=1.13.0 ; python_version >= "3.9"
2122
unittest-xml-reporting<=3.2.0,>=2.0.0
2223
xdoctest==1.1.0
2324
filelock==3.6.0
24-
sympy==1.11.1
2525
pytest-cpp==2.3.0
2626
rockset==1.0.3
2727
z3-solver==4.12.2.0

.lintrunner.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ init_command = [
137137
'numpy==1.26.0 ; python_version >= "3.9"',
138138
'expecttest==0.1.6',
139139
'mypy==1.9.0',
140-
'sympy==1.11.1',
140+
'sympy==1.12.1 ; python_version == "3.8"',
141+
'sympy==1.13.0 ; python_version >= "3.9"',
141142
'types-requests==2.27.25',
142143
'types-PyYAML==6.0.7',
143144
'types-tabulate==0.8.8',

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ requests
99
setuptools
1010
types-dataclasses
1111
typing-extensions>=4.8.0
12-
sympy
12+
sympy==1.12.1 ; python_version == "3.8"
13+
sympy>=1.13.0 ; python_version >= "3.9"
1314
filelock
1415
networkx
1516
jinja2

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,8 @@ def main():
11281128
install_requires = [
11291129
"filelock",
11301130
"typing-extensions>=4.8.0",
1131-
"sympy",
1131+
'sympy==1.12.1 ; python_version == "3.8"',
1132+
'sympy>=1.13.0 ; python_version >= "3.9"',
11321133
"networkx",
11331134
"jinja2",
11341135
"fsspec",

torch/_inductor/scheduler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def pformat(obj: Any) -> str:
638638
obj = sorted(obj, key=str)
639639
result = pprint.pformat(obj, indent=4)
640640
if "\n" in result:
641-
return f"\n{textwrap.indent(result, ' '*4)}"
641+
return f"\n{textwrap.indent(result, ' ' * 4)}"
642642
return result
643643

644644

@@ -1587,7 +1587,7 @@ def add_user(
15871587
NodeUser(user_node, can_inplace, is_weak)
15881588
)
15891589

1590-
unbacked_symbol_to_origin_node = {}
1590+
unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {}
15911591

15921592
# NB: None means that the dependency is on an input. Don't actually
15931593
# generate a dependency because if we do, Inductor will start trying

torch/utils/_sympy/functions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import sympy
77
from sympy import S
8+
from sympy.core.numbers import equal_valued
89

910
__all__ = [
1011
"FloorDiv",
@@ -104,9 +105,9 @@ def eval(cls, base, divisor):
104105

105106
if base.is_zero:
106107
return sympy.S.Zero
107-
if base.is_integer and divisor == 1:
108+
if base.is_integer and equal_valued(divisor, 1):
108109
return base
109-
if base.is_integer and divisor == -1:
110+
if base.is_integer and equal_valued(divisor, -1):
110111
return sympy.Mul(base, -1)
111112
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
112113
return sympy.Integer(int(base) // int(divisor))
@@ -125,7 +126,7 @@ def eval(cls, base, divisor):
125126

126127
try:
127128
gcd = sympy.gcd(base, divisor)
128-
if gcd != 1:
129+
if not equal_valued(gcd, 1):
129130
return FloorDiv(
130131
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
131132
)

torch/utils/_sympy/value_ranges.py

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,19 @@ def __repr__(self) -> str:
132132
return f"VR[{self.lower}, {self.upper}]"
133133

134134
@overload
135-
def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None:
135+
def __init__(
136+
self: ValueRanges[sympy.Expr],
137+
lower: ExprIn,
138+
upper: ExprIn,
139+
) -> None:
136140
...
137141

138142
@overload
139-
def __init__(self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn) -> None:
143+
def __init__( # type: ignore[misc]
144+
self: ValueRanges[SympyBoolean],
145+
lower: BoolIn,
146+
upper: BoolIn,
147+
) -> None:
140148
...
141149

142150
def __init__(self, lower: AllIn, upper: AllIn) -> None:
@@ -149,26 +157,31 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None:
149157
raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
150158
except TypeError as e:
151159
raise TypeError(f"Could not compare {lower} <= {upper}") from e
152-
# Because this is a frozen class
153-
object.__setattr__(self, "lower", lower)
154-
object.__setattr__(self, "upper", upper)
155-
# Unlike bool/int in Python, we don't report bools are ints
156-
object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
157-
if self.is_bool:
158-
assert isinstance(upper, SympyBoolean), (lower, upper)
160+
161+
is_bool_lower = isinstance(lower, SympyBoolean)
162+
is_bool_upper = isinstance(upper, SympyBoolean)
163+
assert is_bool_lower == is_bool_upper, (lower, upper)
159164

160165
# Warning: is_int/is_float is best effort. We do pretty well in
161166
# Dynamo, but in Inductor these attributes are often wrong because we
162167
# are not very rigorous in dtype analysis. This is also why we need
163168
# the flexible analysis for is_int: sometimes a sympy.oo pops in for
164169
# an integer bound. I would /like/ for us not to do this, but it's
165170
# too hard to push the invariant through right now.
171+
is_int_lower = isinstance(lower, sympy.Integer)
172+
is_int_upper = isinstance(upper, sympy.Integer)
166173

174+
# Because this is a frozen class
175+
object.__setattr__(self, "lower", lower)
176+
object.__setattr__(self, "upper", upper)
177+
# Unlike bool/int in Python, we don't report bools are ints
178+
#
179+
# NB: is_bool_lower == is_bool_upper, so we only need to check one
180+
object.__setattr__(self, "is_bool", is_bool_lower)
167181
object.__setattr__(
168182
self,
169183
"is_int",
170-
not self.is_bool
171-
and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)),
184+
not self.is_bool and (is_int_lower or is_int_upper),
172185
)
173186
"""
174187
# This assert is just impossible right now, too many sympy bugs
@@ -209,13 +222,15 @@ def tighten(self, other) -> ValueRanges:
209222
# Intersection
210223
@overload
211224
def __and__(
212-
self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]
225+
self: ValueRanges[sympy.Expr],
226+
other: ValueRanges[sympy.Expr],
213227
) -> ValueRanges[sympy.Expr]:
214228
...
215229

216230
@overload
217-
def __and__(
218-
self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]
231+
def __and__( # type: ignore[misc]
232+
self: ValueRanges[SympyBoolean],
233+
other: ValueRanges[SympyBoolean],
219234
) -> ValueRanges[SympyBoolean]:
220235
...
221236

@@ -239,20 +254,24 @@ def __and__(self: AllVR, other: AllVR) -> AllVR:
239254
# Union
240255
@overload
241256
def __or__(
242-
self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]
257+
self: ValueRanges[sympy.Expr],
258+
other: ValueRanges[sympy.Expr],
243259
) -> ValueRanges[sympy.Expr]:
244260
...
245261

246262
@overload
247-
def __or__(
248-
self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]
263+
def __or__( # type: ignore[misc]
264+
self: ValueRanges[SympyBoolean],
265+
other: ValueRanges[SympyBoolean],
249266
) -> ValueRanges[SympyBoolean]:
250267
...
251268

252269
def __or__(self: AllVR, other: AllVR) -> AllVR:
253270
if ValueRanges.unknown() in (self, other):
254271
return ValueRanges.unknown()
255272
assert self.is_bool == other.is_bool, (self, other)
273+
assert self.is_int == other.is_int, (self, other)
274+
assert self.is_float == other.is_float, (self, other)
256275
if self.is_bool:
257276
return ValueRanges(
258277
sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)
@@ -282,7 +301,7 @@ def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap
282301

283302
@overload
284303
@staticmethod
285-
def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:
304+
def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR: # type: ignore[misc]
286305
...
287306

288307
@staticmethod
@@ -307,7 +326,7 @@ def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
307326

308327
@overload
309328
@staticmethod
310-
def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:
329+
def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR: # type: ignore[misc]
311330
...
312331

313332
@staticmethod
@@ -330,27 +349,36 @@ def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
330349
"""Fn is convex and has a minimum at 0."""
331350
x = ValueRanges.wrap(x)
332351
if 0 in x:
333-
return ValueRanges(0, max(fn(x.lower), fn(x.upper)))
334-
else:
335-
return ValueRanges.monotone_map(x, fn)
352+
upper = max(fn(x.lower), fn(x.upper))
353+
upper = simple_sympify(upper)
354+
if isinstance(upper, sympy.Float) or upper == sympy.oo:
355+
return ValueRanges(0.0, upper)
356+
return ValueRanges(0, upper)
357+
return ValueRanges.monotone_map(x, fn)
336358

337359
@overload
338360
@staticmethod
339361
def coordinatewise_increasing_map(
340-
x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2
362+
x: Union[ExprIn, ExprVR],
363+
y: Union[ExprIn, ExprVR],
364+
fn: ExprFn2,
341365
) -> ExprVR:
342366
...
343367

344368
@overload
345369
@staticmethod
346-
def coordinatewise_increasing_map(
347-
x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2
370+
def coordinatewise_increasing_map( # type: ignore[misc]
371+
x: Union[BoolIn, BoolVR],
372+
y: Union[BoolIn, BoolVR],
373+
fn: BoolFn2,
348374
) -> BoolVR:
349375
...
350376

351377
@staticmethod
352378
def coordinatewise_increasing_map(
353-
x: Union[AllIn, AllVR], y: Union[AllIn, AllVR], fn: AllFn2
379+
x: Union[AllIn, AllVR],
380+
y: Union[AllIn, AllVR],
381+
fn: AllFn2,
354382
) -> AllVR:
355383
"""
356384
It's increasing on each coordinate.
@@ -1001,7 +1029,7 @@ def bound_sympy(
10011029
if unbounded_vars:
10021030
# Give some bounds to the free variables via their SymPy assumptions
10031031
# TODO A better way of doing this would be to assign them a range upon creation, as
1004-
# size variables can come with a lower bound of 2, as we specialise on 0 and 1
1032+
# size variables can come with a lower bound of 2, as we specialize on 0 and 1
10051033
unbounded_ranges: Dict[sympy.Symbol, ValueRanges] = {}
10061034
for s in unbounded_vars:
10071035
if s.is_integer: # type: ignore[attr-defined]

0 commit comments

Comments
 (0)
0