8000 Rewrite complex log1p to improve precision. · WarrenWeckesser/numpy@5c2a847 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5c2a847

Browse files
Rewrite complex log1p to improve precision.
Closes numpygh-22609 [skip circle]
1 parent ccff7fb commit 5c2a847

File tree

7 files changed

+447
-15
lines changed

7 files changed

+447
-15
lines changed

numpy/_core/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,7 @@ src_umath = umath_gen_headers + [
11531153
src_file.process('src/umath/scalarmath.c.src'),
11541154
'src/umath/ufunc_object.c',
11551155
'src/umath/umathmodule.c',
1156+
src_file.process('src/umath/log1p_complex_wrappers.cpp.src'),
11561157
'src/umath/special_integer_comparisons.cpp',
11571158
'src/umath/string_ufuncs.cpp',
11581159
'src/umath/stringdtype_ufuncs.cpp',

numpy/_core/src/umath/funcs.inc.src

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -326,15 +326,6 @@ nc_log@c@(@ctype@ *x, @ctype@ *r)
326326
return;
327327
}
328328

329-
static void
330-
nc_log1p@c@(@ctype@ *x, @ctype@ *r)
331-
{
332-
@ftype@ l = npy_hypot@c@(npy_creal@c@(*x) + 1,npy_cimag@c@(*x));
333-
npy_csetimag@c@(r, npy_atan2@c@(npy_cimag@c@(*x), npy_creal@c@(*x) + 1));
334-
npy_csetreal@c@(r, npy_log@c@(l));
335-
return;
336-
}
337-
338329
static void
339330
nc_exp@c@(@ctype@ *x, @ctype@ *r)
340331
{

numpy/_core/src/umath/log1p_complex.h

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#ifndef LOG1P_COMPLEX_H
2+
#define LOG1P_COMPLEX_H
3+
4+
#include <Python.h>
5+
#include "numpy/ndarraytypes.h"
6+
#include "numpy/npy_math.h"
7+
8+
#include <cmath>
9+
#include <complex>
10+
#include <limits>
11+
12+
// For memcpy
13+
#include <cstring>
14+
15+
16+
//
17+
// Trivial C++ wrappers for several npy_* functions.
18+
//
19+
#define CPP_WRAP1(name) \
20+
inline npy_float name(const npy_float x) \
21+
{ \
22+
return ::npy_ ## name ## f(x); \
23+
} \
24+
inline npy_double name(const npy_double x) \
25+
{ \
26+
return ::npy_ ## name(x); \
27+
} \
28+
inline npy_longdouble name(const npy_longdouble x) \
29+
{ \
30+
67F4 return ::npy_ ## name ## l(x); \
31+
} \
32+
33+
34+
#define CPP_WRAP2(name) \
35+
inline npy_float name(const npy_float x, \
36+
const npy_float y) \
37+
{ \
38+
return ::npy_ ## name ## f(x, y); \
39+
} \
40+
inline npy_double name(const npy_double x, \
41+
const npy_double y) \
42+
{ \
43+
return ::npy_ ## name(x, y); \
44+
} \
45+
inline npy_longdouble name(const npy_longdouble x, \
46+
const npy_longdouble y) \
47+
{ \
48+
return ::npy_ ## name ## l(x, y); \
49+
} \
50+
51+
namespace npy {
52+
53+
CPP_WRAP1(fabs)
54+
CPP_WRAP1(log)
55+
CPP_WRAP1(log1p)
56+
CPP_WRAP2(atan2)
57+
CPP_WRAP2(hypot)
58+
59+
}
60+
61+
namespace log1p_complex
62+
{
63+
64+
template<typename T>
65+
struct doubled_t {
66+
T upper;
67+
T lower;
68+
};
69+
70+
//
71+
// Dekker splitting. See, for example, Theorem 1 of
72+
//
73+
// Seppa Linnainmaa, Software for Double-Precision Floating-Point
74+
// Computations, ACM Transactions on Mathematical Software, Vol 7, No 3,
75+
// September 1981, pages 272-283.
76+
//
77+
// or Theorem 17 of
78+
//
79+
// J. R. Shewchuk, Adaptive Precision Floating-Point Arithmetic and
80+
// Fast Robust Geometric Predicates, CMU-CS-96-140R, from Discrete &
81+
// Computational Geometry 18(3):305-363, October 1997.
82+
//
83+
template<typename T>
84+
inline void
85+
split(T x, doubled_t<T>& out)
86+
{
87+
if (std::numeric_limits<T>::digits == 106) {
88+
// Special case: IBM double-double format. The value is already
89+
// split in memory, so there is no need for any calculations.
90+
std::memcpy(&out, &x, sizeof(out));
91+
}
92+
else {
93+
constexpr int halfprec = (std::numeric_limits<T>::digits + 1)/2;
94+
T t = ((1ul << halfprec) + 1)*x;
95+
// The compiler must not be allowed to simplify this expression:
96+
out.upper = t - (t - x);
97+
out.lower = x - out.upper;
98+
}
99+
}
100+
101+
template<typename T>
102+
inline void
103+
two_sum_quick(T x, T y, doubled_t<T>& out)
104+
{
105+
T r = x + y;
106+
T e = y - (r - x);
107+
out.upper = r;
108+
out.lower = e;
109+
}
110+
111+
template<typename T>
112+
inline void
113+
two_sum(T x, T y, doubled_t<T>& out)
114+
{
115+
T s = x + y;
116+
T v = s - x;
117+
T e = (x - (s - v)) + (y - v);
118+
out.upper = s;
119+
out.lower = e;
120+
}
121+
122+
template<typename T>
123+
inline void
124+
double_sum(const doubled_t<T>& x, const doubled_t<T>& y,
125+
doubled_t<T>& out)
126+
{
127+
two_sum<T>(x.upper, y.upper, out);
128+
out.lower += x.lower + y.lower;
129+
two_sum_quick<T>(out.upper, out.lower, out);
130+
}
131+
132+
template<typename T>
133+
inline void
134+
square(T x, doubled_t<T>& out)
135+
{
136+
doubled_t<T> xsplit;
137+
out.upper = x*x;
138+
split(x, xsplit);
139+
out.lower = xsplit.lower*xsplit.lower
140+
- ((out.upper - xsplit.upper*xsplit.upper)
141+
- 2*xsplit.lower*xsplit.upper);
142+
}
143+
144+
//
145+
// As the name makes clear, this function computes x**2 + 2*x + y**2.
146+
// It uses doubled_t<T> for the intermediate calculations.
147+
// (That is, we give the floating point type T an upgrayedd, spelled with
148+
// two d's for a double dose of precision.)
149+
//
150+
// The function is used in log1p_complex() to avoid the loss of
151+
// precision that can occur in the expression when x**2 + y**2 ≈ -2*x.
152+
//
153+
template<typename T>
154+
inline T
155+
xsquared_plus_2x_plus_ysquared_dd(T x, T y)
156+
{
157+
doubled_t<T> x2, y2, twox, sum1, sum2;
158+
159+
square<T>(x, x2); // x2 = x**2
160+
square<T>(y, y2); // y2 = y**2
161+
twox.upper = 2*x; // twox = 2*x
162+
twox.lower = 0.0;
163+
double_sum<T>(x2, twox, sum1); // sum1 = x**2 + 2*x
164+
double_sum<T>(sum1, y2, sum2); // sum2 = x**2 + 2*x + y**2
165+
return sum2.upper;
166+
}
167+
168+
//
169+
// For the float type, the intermediate calculation is done
170+
// with the double type. We don't need to use doubled_t<float>.
171+
//
172+
inline float
173+
xsquared_plus_2x_plus_ysquared(float x, float y)
174+
{
175+
double xd = x;
176+
double yd = y;
177+
return xd*(2.0 + xd) + yd*yd;
178+
}
179+
180+
//
181+
// For double, we used doubled_t<double> if long double doesn't have
182+
// at least 106 bits of precision.
183+
//
184+
inline double
185+
xsquared_plus_2x_plus_ysquared(double x, double y)
186+
{
187+
if (std::numeric_limits<long double>::digits >= 106) {
188+
// Cast to long double for the calculation.
189+
long double xd = x;
190+
long double yd = y;
191+
return xd*(2.0L + xd) + yd*yd;
192+
}
193+
else {
194+
// Use doubled_t<double> for the calculation.
195+
return xsquared_plus_2x_plus_ysquared_dd<double>(x, y);
196+
}
197+
}
198+
199+
//
200+
// For long double, we always use doubled_t<long double> for the
201+
// calculation.
202+
//
203+
inline long double
204+
xsquared_plus_2x_plus_ysquared(long double x, long double y)
205+
{
206+
return xsquared_plus_2x_plus_ysquared_dd<long double>(x, y);
207+
}
208+
209+
//
210+
// Implement log1p(z) for complex inputs, using higher precision near
211+
// the unit circle |z + 1| = 1 to avoid loss of precision that can occur
212+
// when x**2 + y**2 ≈ -2*x.
213+
//
214+
// If the real or imaginary part of z is NAN, complex(NAN, NAN) is returned.
215+
//
216+
template<typename T>
217+
inline std::complex<T>
218+
log1p_complex(std::complex<T> z)
219+
{
220+
T lnr;
221+
T x = z.real();
222+
T y = z.imag();
223+
if (std::isnan(x) || std::isnan(y)) {
224+
return std::complex<T>(NPY_NAN, NPY_NAN);
225+
}
226+
if (x > -2.2 && x < 0.2 && y > -1.2 && y < 1.2
227+
&& npy::fabs(x*(2 + x) + y*y) < 0.4) {
228+
// The input is close to the unit circle centered at -1+0j.
229+
// Compute x**2 + 2*x + y**2 with higher precision than T.
230+
// The calculation here is equivalent to log(hypot(x+1, y)),
231+
// since
232+
// log(hypot(x+1, y)) = 0.5*log(x**2 + 2*x + 1 + y**2)
233+
// = 0.5*log1p(x**2 + 2*x + y**2)
234+
T t = xsquared_plus_2x_plus_ysquared(x, y);
235+
lnr = 0.5*npy::log1p(t);
236+
}
237+
else {
238+
lnr = npy::log(npy::hypot(x + static_cast<T>(1), y));
239+
}
240+
return std::complex<T>(lnr, npy::atan2(y, x + static_cast<T>(1)));
241+
}
242+
243+
} // namespace log1p_complex
244+
245+
#endif
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <Python.h>
2+
#include "numpy/npy_math.h"
3+
#include "numpy/ndarraytypes.h"
4+
#include <complex>
5+
#include "log1p_complex.h"
6+
7+
extern "C" {
8+
9+
/**begin repeat
10+
* #fptype = float, double, long double#
11+
* #ctype = npy_cfloat,npy_cdouble,npy_clongdouble#
12+
* #c = f, , l#
13+
*/
14+
15+
NPY_NO_EXPORT void
16+
nc_log1p@c@(@ctype@ *x, @ctype@ *r)
17+
{
18+
std::complex<@fptype@> z(npy_creal@c@(*x), npy_cimag@c@(*x));
19+
auto w = log1p_complex::log1p_complex(z);
20+
npy_csetreal@c@(r, w.real());
21+
npy_csetimag@c@(r, w.imag());
22+
}
23+
24+
/**end repeat**/
25+
26+
27+
} // extern "C"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef LOG1P_COMPLEX_WRAPPERS_H
2+
#define LOG1P_COMPLEX_WRAPPERS_H
3+
4+
// This header is to be included in umathmodule.c,
5+
// so it will be processed by a C compiler.
6+
//
7+
// This file assumes that the numpy header files have
8+
// already been included.
9+
10+
NPY_NO_EXPORT void
11+
nc_log1pf(npy_cfloat *x, npy_cfloat *r);
12+
13+
NPY_NO_EXPORT void
14+
nc_log1p(npy_cdouble *x, npy_cdouble *r);
15+
16+
NPY_NO_EXPORT void
17+
nc_log1pl(npy_clongdouble *x, npy_clongdouble *r);
18+
19+
#endif // LOG1P_COMPLEX_WRAPPERS_H

numpy/_core/src/umath/umathmodule.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "string_ufuncs.h"
3131
#include "stringdtype_ufuncs.h"
3232
#include "special_integer_comparisons.h"
33+
#include "log1p_complex_wrappers.h"
3334
#include "extobj.h" /* for _extobject_contextvar exposure */
3435

3536
/* Automatically generated code to define all ufuncs: */

0 commit comments

Comments
 (0)
0