8000 Keep zero check be compatible with different sympy versions (#130729) · pytorch/pytorch@096dc44 · GitHub
[go: up one dir, main page]

Skip to content

Commit 096dc44

Browse files
guangyeypytorchmergebot
authored andcommitted
Keep zero check be compatible with different sympy versions (#130729)
# Motivation I found a difference between sympy 1.12 and 1.13. ```python # for 1.12 >>> import sympy >>> a = sympy.Number(0.0) >>> a == 0 True ``` ```python # for 1.13 >>> import sympy >>> a = sympy.Number(0.0) >>> a == 0 False ``` The different behavior will impact the result of [safe_mul](https://github.com/pytorch/pytorch/blob/6beec34b1c6803d5f6648c3cd7c262d6432374c8/torch/utils/_sympy/value_ranges.py#L521-L528), resulting in an incorrect results when `a = sympy.Number(0.0)`, `b = inf` and the result is `nan` if sympy version is 1.13. (the expected result is **0**) ```python def safe_mul(a, b): # Make unknown() * wrap(0.0) == wrap(0.0) if a == 0.0: return a elif b == 0.0: return b else: return a * b ``` In different sympy versions, `sympy.Number(0)` always has the same behavior that equals to 0.0. ```python >>> import sympy >>> a = sympy.Number(0) >>> a == 0.0 True # for different sympy versions ``` So, use 0.0 when checking zero in safe_mul to keep compatible with different sympy versions. Pull Request resolved: #130729 Approved by: https://github.com/lezcano, https://github.com/EikanWang
1 parent fedae41 commit 096dc44

File tree

2 files changed

+7
-3
lines changed

2 files changed

+ 8000 7
-3
lines changed

test/test_sympy_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ def test_mul_zero_unknown(self):
241241
ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
242242
ValueRanges.wrap(0),
243243
)
244+
self.assertEqual(
245+
ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
246+
ValueRanges.wrap(0.0),
247+
)
244248

245249
@parametrize("fn", UNARY_BOOL_OPS)
246250
def test_unary_bool_ref_range(self, fn):

torch/utils/_sympy/value_ranges.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,10 +519,10 @@ def mul(cls, a, b):
519519
return cls.and_(a, b)
520520

521521
def safe_mul(a, b):
522-
# Make unknown() * wrap(0) == wrap(0)
523-
if a == 0:
522+
# Make unknown() * wrap(0.0) == wrap(0.0)
523+
if a == 0.0:
524524
return a
525-
elif b == 0:
525+
elif b == 0.0:
526526
return b
527527
else:
528528
return a * b

0 commit comments

Comments
 (0)
0