10000 BUG: umath: Fix log1p for complex inputs. · WarrenWeckesser/numpy@2511f87 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2511f87

Browse files
BUG: umath: Fix log1p for complex inputs.
Reimplement the complex log1p function. Use the log1p trick from Theorem 4 of Goldberg's paper "What every computer scientist should know about floating-point arithmetic". Include special handling of an input with imaginary part 0.0 to ensure the sign of the imaginary part of the result is correct and consistent with the complex log function. Closes numpygh-22609.
1 parent 4ec0182 commit 2511f87

File tree

2 files changed

+147
-9
lines changed

2 files changed

+147
-9
lines changed

numpy/core/src/umath/funcs.inc.src

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,55 @@ nc_log@c@(@ctype@ *x, @ctype@ *r)
329329
static void
330330
nc_log1p@c@(@ctype@ *x, @ctype@ *r)
331331
{
332-
@ftype@ l = npy_hypot@c@(npy_creal@c@(*x) + 1,npy_cimag@c@(*x));
333-
npy_csetimag@c@(r, npy_atan2@c@(npy_cimag@c@(*x), npy_creal@c@(*x) + 1));
334-
npy_csetreal@c@(r, npy_log@c@(l));
332+
@ftype@ x_re = npy_creal@c@(*x);
333+
@ftype@ x_im = npy_cimag@c@(*x);
334+
335+
if (x_im == 0.0) {
336+
/*
337+
* Imaginary part of *x is +/- 0.0. Use the real-valued function
338+
* log1p or log to compute the real part of the result. If the
339+
* input is on the branch cut, the imaginary part of the result is
340+
* +/- pi, otherwise it is +/- 0.0 (i.e. same as x_im).
341+
*/
342+
if (npy_isnan(x_re)) {
343+
npy_csetreal@c@(r, NAN);
344+
npy_csetimag@c@(r, NAN);
345+
}
346+
else if (x_re >= -1) {
347+
/*
348+
* Note that if x_re == -1, this will generate a "divide by zero"
349+
* RuntimeWarning and set the real part of the result to -inf
350+
* (consistent with log(0.0 +/- 0.0j)).
351+
*/
352+
npy_csetreal@c@(r, npy_log1p@c@(x_re));
353+
npy_csetimag@c@(r, x_im);
354+
}
355+
else {
356+
/*
357+
* On the branch cut.
358+
*/
359+
npy_csetreal@c@(r, npy_log@c@(-(1 + x_re)));
360+
npy_csetimag@c@(r, npy_copysign@c@(NPY_PI@c@, x_im));
361+
}
362+
}
363+
else {
364+
/*
365+
* Use the log1p trick given as Theorem 4 of Goldberg's paper "What
366+
* every computer scientist should know about floating-point
367+
* arithmetic".
368+
*/
369+
@ctype@ u = *x + 1.0;
370+
if (npy_creal@c@(u) - 1.0 == x_re) {
371+
/*
372+
* Don't bother multiplying by *x / (u - 1), because that quotient
373+
* is 1, and the complex division might introduce a (small) error.
374+
*/
375+
*r = npy_clog@c@(u);
376+
}
377+
else {
378+
*r = npy_clog@c@(u) * (*x / (u - 1.0));
379+
}
380+
}
335381
return;
336382
}
337383

numpy/core/tests/test_umath.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2004,6 +2004,7 @@ def test_strided_float32(self):
20042004
assert_array_almost_equal_nulp(np.sin(x_f32_large[::jj]), sin_true[::jj], nulp=2)
20052005
assert_array_almost_equal_nulp(np.cos(x_f32_large[::jj]), cos_true[::jj], nulp=2)
20062006

2007+
20072008
class TestLogAddExp(_FilterInvalids):
20082009
def test_logaddexp_values(self):
20092010
x = [1, 2, 3, 4, 5]
@@ -2054,13 +2055,104 @@ def test_log1p(self):
20542055
assert_almost_equal(ncu.log1p(0.2), ncu.log(1.2))
20552056
assert_almost_equal(ncu.log1p(1e-6), ncu.log(1+1e-6))
20562057

2057-
def test_special(self):
2058+
# Special cases that we test for equality. The floating point warnings
2059+
# triggered by some of these values are ignored.
2060+
@pytest.mark.parametrize(
2061+
'z, wref',
2062+
[(np.nan, np.nan),
2063+
(np.inf, np.inf),
2064+
(-1.0, -np.inf),
2065+
(-2.0, np.nan),
2066+
(-np.inf, np.nan),
2067+
(complex(np.nan, 0.0), complex(np.nan, np.nan)),
2068+
(complex(np.nan, 1), complex(np.nan, np.nan)),
2069+
(complex(1, np.nan), complex(np.nan, np.nan)),
2070+
(complex(np.inf, 1), complex(np.inf, 0.0)),
2071+
(complex(np.inf, -1), complex(np.inf, -0.0)),
2072+
(complex(-np.inf, 1), complex(np.inf, np.pi)),
2073+
(complex(-np.inf, -1), complex(np.inf, -np.pi)),
2074+
(complex(-np.inf, 0.0), complex(np.inf, np.pi)),
2075+
(complex(-np.inf, -0.0), complex(np.inf, -np.pi)),
2076+
(complex(0, np.inf), complex(np.inf, np.pi/2)),
2077+
(complex(0, -np.inf), complex(np.inf, -np.pi/2)),
2078+
(complex(-1, 0), complex(-np.inf, 0))]
2079+
)
2080+
def test_special(self, z, wref):
20582081
with np.errstate(invalid="ignore", divide="ignore"):
2059-
assert_equal(ncu.log1p(np.nan), np.nan)
2060-
assert_equal(ncu.log1p(np.inf), np.inf)
2061-
assert_equal(ncu.log1p(-1.), -np.inf)
2062-
assert_equal(ncu.log1p(-2.), np.nan)
2063-
assert_equal(ncu.log1p(-np.inf), np.nan)
2082+
w = np.log1p(z)
2083+
assert_equal(w, wref)
2084+
2085+
# Test w = log1p(z) for complex z (np.complex128).
2086+
# Reference values were computed with mpmath, e.g.
2087+
# from mpmath import mp
2088+
# mp.dps = 200
2089+
# wref = complex(mp.log1p(z))
2090+
@pytest.mark.parametrize(
2091+
'z, wref',
2092+
[(1e-280 + 0j, 1e-280 + 0j),
2093+
(1e-18 + 0j, 1e-18 + 0j),
2094+
(1e-18 + 1e-12j, 1.0000005e-18 + 1e-12j),
2095+
(1e-18 + 0.1j, 0.0049751654265840425 + 0.09966865249116204j),
2096+
(-1e-15 + 3e-8j, -5.5e-16 + 3.000000000000002e-08j),
2097+
(1e-15 + 3e-8j, 1.4499999999999983e-15 + 2.999999999999996e-08j),
2098+
(1e-50 + 1e-200j, 1e-50 + 1e-200j),
2099+
(1e-200 - 1e-200j, 1e-200 - 1e-200j),
2100+
(1e-18j, 5.0000000000000005e-37 + 1e-18j),
2101+
(-4.999958e-05 - 0.009999833j,
2102+
-7.0554155328678184e-15 - 0.009999999665816696j),
2103+
(3.4259e-13 + 6.71894e-08j,
2104+
3.448472077361198e-13 + 6.718939999997688e-08j),
2105+
(0.1 + 1e-18j, 0.09531017980432487+9.090909090909091e-19j),
2106+
(-0.57113 - 0.90337j, 3.4168883248419116e-06 - 1.1275564209486122j),
2107+
(0.2 + 0.3j, 0.21263386770217205 + 0.24497866312686414j),
2108+
(1e200 + 1e200j, 460.8635921890891 + 0.7853981633974483j),
2109+
(-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
2110+
(1e250 + 1j, 575.6462732485114 + 1e-250j),
2111+
(1e275 + 1e-225j, 633.2109005733626 + 0j),
2112+
(-0.75 + 0j, -1.3862943611198906 + 0j)],
2113+
)
2114+
def test_complex_double(self, z, wref):
2115+
w = np.log1p(z)
2116+
assert_allclose(w, wref, rtol=1e-15)
2117+
2118+
# Test w = log1p(z) for complex z (np.complex64).
2119+
# Reference values were computed with mpmath, e.g.
2120+
# from mpmath import mp
2121+
# mp.dps = 200
2122+
# wref = np.complex64(mp.log1p(z))
2123+
@pytest.mark.parametrize(
2124+
'z, wref',
2125+
[(np.complex64(1e-10 + 3e-6j), np.complex64(1.045e-10 + 3e-06j)),
2126+
(np.complex64(-1e-8 - 2e-5j), np.complex64(-9.8e-09 - 2e-05j)),
2127+
(np.complex64(-2e-32 + 3e-32j), np.complex64(-2e-32 + 3e-32j)),
2128+
(np.complex64(-0.57113 - 0.90337j),
2129+
np.complex64(3.4470238e-06 - 1.1275564j)),
2130+
(np.complex64(3e31 - 4e31j), np.complex64(72.98958 - 0.9272952j))]
2131+
)
2132+
def test_complex_single(self, z, wref):
2133+
w = np.log1p(z)
2134+
assert_allclose(w, wref, rtol=1e-6)
2135+
2136+
def test_branch_cut(self):
2137+
x = -1.5
2138+
zpos = complex(x, 0.0)
2139+
wpos = np.log1p(zpos)
2140+
assert wpos.imag == np.pi
2141+
zneg = complex(x, -0.0)
2142+
wneg = np.log1p(zneg)
2143+
assert wneg.imag == -np.pi
2144+
assert wpos.real == wneg.real
2145+
2146+
@pytest.mark.parametrize('x', [-0.5, -1e-12, -1e-18, 0.0, 1e-18, 1e-12])
2147+
def test_imag_zero(self, x):
2148+
# Test real inputs with x > -1 and the imaginary part +/- 0.0.
2149+
zpos = complex(x, 0.0)
2150+
wpos = np.log1p(zpos)
2151+
assert_equal(wpos.imag, 0.0)
2152+
zneg = complex(x, -0.0)
2153+
wneg = np.log1p(zneg)
2154+
assert_equal(wneg.imag, -0.0)
2155+
assert wpos.real == wneg.real
20642156

20652157

20662158
class TestExpm1:

0 commit comments

Comments
 (0)
0