8000 BUG: ticket #1776, make complex division by zero to yield inf properly. · certik/numpy@88e7e52 · GitHub
[go: up one dir, main page]

Skip to content

Commit 88e7e52

Browse files
pvcharris
authored andcommitted
BUG: ticket numpy#1776, make complex division by zero to yield inf properly.
1 parent 5f940ca commit 88e7e52

File tree

4 files changed

+64
-9
lines changed

4 files changed

+64
-9
lines changed

numpy/core/src/scalarmathmodule.c.src

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,17 @@ static npy_half (*_basic_half_fmod)(npy_half, npy_half);
382382
(outp)->real = (a).real * (b).real - (a).imag * (b).imag; \
383383
(outp)->imag = (a).real * (b).imag + (a).imag * (b).real; \
384384
} while(0)
385-
#define @name@_ctype_divide(a, b, outp) do{ \
386-
@rtype@ d = (b).real*(b).real + (b).imag*(b).imag; \
387-
(outp)->real = ((a).real*(b).real + (a).imag*(b).imag)/d; \
388-
(outp)->imag = ((a).imag*(b).real - (a).real*(b).imag)/d; \
385+
/* Note: complex division by zero must yield some complex inf */
386+
#define @name@_ctype_divide(a, b, outp) do{ \
387+
@rtype@ d = (b).real*(b).real + (b).imag*(b).imag; \
388+
if (d != 0) { \
389+
(outp)->real = ((a).real*(b).real + (a).imag*(b).imag)/d; \
390+
(outp)->imag = ((a).imag*(b).real - (a).real*(b).imag)/d; \
391+
} \
392+
else { \
393+
(outp)->real = (a).real/d; \
394+
(outp)->imag = (a).imag/d; \
395+
} \
389396
} while(0)
390397
#define @name@_ctype_true_divide @name@_ctype_divide
391398
#define @name@_ctype_floor_divide(a, b, outp) do { \

numpy/core/src/umath/loops.c.src

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,11 +1804,20 @@ C@TYPE@_divide(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func
18041804
const @type@ in1i = ((@type@ *)ip1)[1];
18051805
const @type@ in2r = ((@type@ *)ip2)[0];
18061806
const @type@ in2i = ((@type@ *)ip2)[1];
1807-
if (npy_fabs@c@(in2r) >= npy_fabs@c@(in2i)) {
1808-
const @type@ rat = in2i/in2r;
1809-
const @type@ scl = 1.0@c@/(in2r + in2i*rat);
1810-
((@type@ *)op1)[0] = (in1r + in1i*rat)*scl;
1811-
((@type@ *)op1)[1] = (in1i - in1r*rat)*scl;
1807+
const @type@ in2r_abs = npy_fabs@c@(in2r);
1808+
const @type@ in2i_abs = npy_fabs@c@(in2i);
1809+
if (in2r_abs >= in2i_abs) {
1810+
if (in2r_abs == 0 && in2i_abs == 0) {
1811+
/* divide by zero should yield a complex inf or nan */
1812+
((@type@ *)op1)[0] = in1r/in2r_abs;
1813+
((@type@ *)op1)[1] = in1i/in2i_abs;
1814+
}
1815+
else {
1816+
const @type@ rat = in2i/in2r;
1817+
const @type@ scl = 1.0@c@/(in2r + in2i*rat);
1818+
((@type@ *)op1)[0] = (in1r + in1i*rat)*scl;
1819+
((@type@ *)op1)[1] = (in1i - in1r*rat)*scl;
1820+
}
18121821
}
18131822
else {
18141823
const @type@ rat = in2r/in2i;

numpy/core/tests/test_scalarmath.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,28 @@ def test_large_types(self):
6060
assert_almost_equal(b, 6765201, err_msg=msg)
6161

6262

63+
class TestComplexDivision(TestCase):
64+
def test_zero_division(self):
65+
err = np.seterr(over="ignore")
66+
try:
67+
for t in [np.complex64, np.complex128]:
68+
a = t(0.0)
69+
b = t(1.0)
70+
assert_(np.isinf(b/a))
71+
b = t(complex(np.inf, np.inf))
72+
assert_(np.isinf(b/a))
73+
b = t(complex(np.inf, np.nan))
74+
assert_(np.isinf(b/a))
75+
b = t(complex(np.nan, np.inf))
76+
assert_(np.isinf(b/a))
77+
b = t(complex(np.nan, np.nan))
78+
assert_(np.isnan(b/a))
79+
b = t(0.)
80+
assert_(np.isnan(b/a))
81+
finally:
82+
np.seterr(**err)
83+
84+
6385
class TestConversion(TestCase):
6486
def test_int_from_long(self):
6587
l = [1e6, 1e12, 1e18, -1e6, -1e12, -1e18]

numpy/core/tests/test_umath.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@ def test_division_complex(self):
2929
y = x**2/x
3030
assert_almost_equal(y/x, [1, 1], err_msg=msg)
3131

32+
def test_zero_division_complex(self):
33+
err = np.seterr(invalid="ignore")
34+
try:
35+
x = np.array([0.0], dtype=np.complex128)
36+
y = 1.0/x
37+
assert_(np.isinf(y)[0])
38+
y = complex(np.inf, np.nan)/x
39+
assert_(np.isinf(y)[0])
40+
y = complex(np.nan, np.inf)/x
41+
assert_(np.isinf(y)[0])
42+
y = complex(np.inf, np.inf)/x
43+
assert_(np.isinf(y)[0])
44+
y = 0.0/x
45+
assert_(np.isnan(y)[0])
46+
finally:
47+
np.seterr(**err)
48+
3249
def test_floor_division_complex(self):
3350
# check that implementation is correct
3451
msg = "Complex floor division implementation check"

0 commit comments

Comments
 (0)
0