8000 BUG: core: make complex division by zero to yield inf properly (#1776) · numpy/numpy@65b77ee · GitHub
[go: up one dir, main page]

Skip to content

Commit 65b77ee

Browse files
committed
BUG: core: make complex division by zero to yield inf properly (#1776)
1 parent b8101c9 commit 65b77ee

File tree

4 files changed

+64
-9
lines changed

4 files changed

+64
-9
lines changed

numpy/core/src/scalarmathmodule.c.src

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,17 @@ static npy_half (*_basic_half_fmod)(npy_half, npy_half);
382382
(outp)->real = (a).real * (b).real - (a).imag * (b).imag; \
383383
(outp)->imag = (a).real * (b).imag + (a).imag * (b).real; \
384384
} while(0)
385-
#define @name@_ctype_divide(a, b, outp) do{ \
386-
@rtype@ d = (b).real*(b).real + (b).imag*(b).imag; \
387-
(outp)->real = ((a).real*(b).real + (a).imag*(b).imag)/d; \
388-
(outp)->imag = ((a).imag*(b).real - (a).real*(b).imag)/d; \
385+
/* Note: complex division by zero must yield some complex inf */
386+
#define @name@_ctype_divide(a, b, outp) do{ \
387+
@rtype@ d = (b).real*(b).real + (b).imag*(b).imag; \
388+
if (d != 0) { \
389+
(outp)->real = ((a).real*(b).real + (a).imag*(b).imag)/d; \
390+
(outp)->imag = ((a).imag*(b).real - (a).real*(b).imag)/d; \
391+
} \
392+
else { \
393+
(outp)->real = (a).real/d; \
394+
(outp)->imag = (a).imag/d; \
395+
} \
389396
} while(0)
390397
#define @name@_ctype_true_divide @name@_ctype_divide
391398
#define @name@_ctype_floor_divide(a, b, outp) do { \

numpy/core/src/umath/loops.c.src

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,11 +1804,20 @@ C@TYPE@_divide(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func
18041804
const @type@ in1i = ((@type@ *)ip1)[1];
18051805
const @type@ in2r = ((@type@ *)ip2)[0];
18061806
const @type@ in2i = ((@type@ *)ip2)[1];
1807-
if (npy_fabs@c@(in2r) >= npy_fabs@c@(in2i)) {
1808-
const @type@ rat = in2i/in2r;
1809-
const @type@ scl = 1.0@c@/(in2r + in2i*rat);
1810-
((@type@ *)op1)[0] = (in1r + in1i*rat)*scl;
1811-
((@type@ *)op1)[1] = (in1i - in1r*rat)*scl;
1807+
const @type@ in2r_abs = npy_fabs@c@(in2r);
1808+
const @type@ in2i_abs = npy_fabs@c@(in2i);
1809+
if (in2r_abs >= in2i_abs) {
1810+
if (in2r_abs == 0 && in2i_abs == 0) {
1811+
/* divide by zero should yield a complex inf or nan */
1812+
((@type@ *)op1)[0] = in1r/in2r_abs;
1813+
((@type@ *)op1)[1] = in1i/in2i_abs;
1814+
}
1815+
else {
1816+
const @type@ rat = in2i/in2r;
1817+
const @type@ scl = 1.0@c@/(in2r + in2i*rat);
1818+
((@type@ *)op1)[0] = (in1r + in1i*rat)*scl;
1819+
((@type@ *)op1)[1] = (in1i - in1r*rat)*scl;
1820+
}
18121821
}
18131822
else {
18141823
const @type@ rat = in2r/in2i;

numpy/core/tests/test_scalarmath.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,28 @@ def test_large_types(self):
6060
assert_almost_equal(b, 6765201, err_msg=msg)
6161

6262

63+
class TestComplexDivision(TestCase):
64+
def test_zero_division(self):
65+
err = np.seterr(over="ignore")
66+
try:
67+
for t in [np.complex64, np.complex128]:
68+
a = t(0.0)
69+
b = t(1.0)
70+
assert_(np.isinf(b/a))
71+
b = t(complex(np.inf, np.inf))
72+
assert_(np.isinf(b/a))
73+
b = t(complex(np.inf, np.nan))
74+
assert_(np.isinf(b/a))
75+
b = t(complex(np.nan, np.inf))
76+
assert_(np.isinf(b/a))
77+
b = t(complex(np.nan, np.nan))
78+
assert_(np.isnan(b/a))
79+
b = t(0.)
80+
assert_(np.isnan(b/a))
81+
finally:
82+
np.seterr(**err)
83+
84+
6385
class TestConversion(TestCase):
6486
def test_int_from_long(self):
6587
l = [1e6, 1e12, 1e18, -1e6, -1e12, -1e18]

numpy/core/tests/test_umath.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff lin 57AE e change
@@ -28,6 +28,23 @@ def test_division_complex(self):
2828
y = x**2/x
2929
assert_almost_equal(y/x, [1, 1], err_msg=msg)
3030

31+
def test_zero_division_complex(self):
32+
err = np.seterr(invalid="ignore")
33+
try:
34+
x = np.array([0.0], dtype=np.complex128)
35+
y = 1.0/x
36+
assert_(np.isinf(y)[0])
37+
y = complex(np.inf, np.nan)/x
38+
assert_(np.isinf(y)[0])
39+
y = complex(np.nan, np.inf)/x
40+
assert_(np.isinf(y)[0])
41+
y = complex(np.inf, np.inf)/x
42+
assert_(np.isinf(y)[0])
43+
y = 0.0/x
44+
assert_(np.isnan(y)[0])
45+
finally:
46+
np.seterr(**err)
47+
3148
def test_floor_division_complex(self):
3249
# check that implementation is correct
3350
msg = "Complex floor division implementation check"

0 commit comments

Comments
 (0)
0