8000 Rewrite complex log1p to improve precision. · WarrenWeckesser/numpy@1bebfe8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1bebfe8

Browse files
Rewrite complex log1p to improve precision.
Closes numpygh-22609
1 parent 7687245 commit 1bebfe8

File tree

7 files changed

+508
-15
lines changed

7 files changed

+508
-15
lines changed

numpy/_core/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,7 @@ src_umath = umath_gen_headers + [
11541154
src_file.process('src/umath/scalarmath.c.src'),
11551155
'src/umath/ufunc_object.c',
11561156
'src/umath/umathmodule.c',
1157+
src_file.process('src/umath/log1p_complex_wrappers.cpp.src'),
11571158
'src/umath/special_integer_comparisons.cpp',
11581159
'src/umath/string_ufuncs.cpp',
11591160
'src/umath/stringdtype_ufuncs.cpp',

numpy/_core/src/umath/funcs.inc.src

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -307,15 +307,6 @@ nc_log@c@(@ctype@ *x, @ctype@ *r)
307307
return;
308308
}
309309

310-
static void
311-
nc_log1p@c@(@ctype@ *x, @ctype@ *r)
312-
{
313-
@ftype@ l = npy_hypot@c@(npy_creal@c@(*x) + 1,npy_cimag@c@(*x));
314-
npy_csetimag@c@(r, npy_atan2@c@(npy_cimag@c@(*x), npy_creal@c@(*x) + 1));
315-
npy_csetreal@c@(r, npy_log@c@(l));
316-
return;
317-
}
318-
319310
static void
320311
nc_exp@c@(@ctype@ *x, @ctype@ *r)
321312
{

numpy/_core/src/umath/log1p_complex.h

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
#ifndef LOG1P_COMPLEX_H
2+
#define LOG1P_COMPLEX_H
3+
4+
#include <Python.h>
5+
#include "numpy/ndarraytypes.h"
6+
#include "numpy/npy_math.h"
7+
8+
#include <cmath>
9+
#include <complex>
10+
#include <limits>
11+
12+
// For memcpy
13+
#include <cstring>
14+
15+
16+
//
17+
// Trivial C++ wrappers for several npy_* functions.
18+
//
19+
20+
#define CPP_WRAP1(name) \
21+
inline float name(const float x) \
22+
{ \
23+
return ::npy_ ## name ## f(x); \
24+
} \
25+
inline double name(const double x) \
26+
{ \
27+
return ::npy_ ## name(x); \
28+
} \
29+
inline long double name(const long double x) \
30+
{ \
31+
return ::npy_ ## name ## l(x); \
32+
} \
33+
34+
#define CPP_WRAP2(name) \
35+
inline float name(const float x, \
36+
const float y) \
37+
{ \
38+
return ::npy_ ## name ## f(x, y); \
39+
} \
40+
inline double name(const double x, \
41+
const double y) \
42+
{ \
43+
return ::npy_ ## name(x, y); \
44+
} \
45+
inline long double name(const long double x, \
46+
const long double y) \
47+
{ \
48+
return ::npy_ ## name ## l(x, y); \
49+
} \
50+
51+
52+
namespace npy {
53+
54+
CPP_WRAP1(fabs)
55+
CPP_WRAP1(log)
56+
CPP_WRAP1(log1p)
57+
CPP_WRAP2(atan2)
58+
CPP_WRAP2(hypot)
59+
60+
}
61+
62+
namespace log1p_complex
63+
{
64+
65+
template<typename T>
66+
struct doubled_t {
67+
T upper;
68+
T lower;
69+
};
70+
71+
//
72+
// There are three functions below where it is crucial that the
73+
// expressions are not optimized. E.g. `t - (t - x)` must not be
74+
// simplified by the compiler to just `x`. The NO_OPT macro defines
75+
// an attribute that should turn off optimization for the function.
76+
//
77+
// The inclusion of `gnu::target("fpmath=sse")` when __GNUC__ and
78+
// __i386 are defined also turns off the use of the floating-point
79+
// unit '387'. It is important that when the type is, for example,
80+
// `double`, these functions compute their results with 64 bit
81+
// precision, and not with 80 bit extended precision.
82+
//
83+
84+
#if defined(__clang__)
85+
#define NO_OPT [[clang::optnone]]
86+
#elif defined(__GNUC__)
87+
#if defined(__i386)
88+
#define NO_OPT [[gnu::optimize(0),gnu::target("fpmath=sse")]]
89+
#else
90+
#define NO_OPT [[gnu::optimize(0)]]
91+
#endif
92+
#else
93+
#define NO_OPT
94+
#endif
95+
96+
//
97+
// Dekker splitting. See, for example, Theorem 1 of
98+
//
99+
// Seppa Linnainmaa, Software for Double-Precision Floating-Point
100+
// Computations, ACM Transactions on Mathematical Software, Vol 7, No 3,
101+
// September 1981, pages 272-283.
102+
//
103+
// or Theorem 17 of
104+
//
105+
// J. R. Shewchuk, Adaptive Precision Floating-Point Arithmetic and
106+
// Fast Robust Geometric Predicates, CMU-CS-96-140R, from Discrete &
107+
// Computational Geometry 18(3):305-363, October 1997.
108+
//
109+
template<typename T>
110+
NO_OPT inline void
111+
split(T x, doubled_t<T>& out)
112+
{
113+
if (std::numeric_limits<T>::digits == 106) {
114+
// Special case: IBM double-double format. The value is already
115+
// split in memory, so there is no need for any calculations.
116+
std::memcpy(&out, &x, sizeof(out));
117+
}
118+
else {
119+
constexpr int halfprec = (std::numeric_limits<T>::digits + 1)/2;
120+
T t = ((1ull << halfprec) + 1)*x;
121+
// The compiler must not be allowed to simplify this expression:
122+
out.upper = t - (t - x);
123+
out.lower = x - out.upper;
124+
}
125+
}
126+
127+
template<typename T>
128+
NO_OPT inline void
129+
two_sum_quick(T x, T y, doubled_t<T>& out)
130+
{
131+
T r = x + y;
132+
T e = y - (r - x);
133+
out.upper = r;
134+
out.lower = e;
135+
}
136+
137+
template<typename T>
138+
NO_OPT inline void
139+
two_sum(T x, T y, doubled_t<T>& out)
140+
{
141+
T s = x + y;
142+
T v = s - x;
143+
T e = (x - (s - v)) + (y - v);
144+
out.upper = s;
145+
out.lower = e;
146+
}
147+
148+
template<typename T>
149+
inline void
150+
double_sum(const doubled_t<T>& x, const doubled_t<T>& y,
151+
doubled_t<T>& out)
152+
{
153+
two_sum<T>(x.upper, y.upper, out);
154+
out.lower += x.lower + y.lower;
155+
two_sum_quick<T>(out.upper, out.lower, out);
156+
}
157+
158+
template<typename T>
159+
inline void
160+
square(T x, doubled_t<T>& out)
161+
{
162+
doubled_t<T> xsplit;
163+
out.upper = x*x;
164+
split(x, xsplit);
165+
out.lower = xsplit.lower*xsplit.lower
166+
- ((out.upper - xsplit.upper*xsplit.upper)
167+
- 2*xsplit.lower*xsplit.upper);
168+
}
169+
170+
//
171+
// As the name makes clear, this function computes x**2 + 2*x + y**2.
172+
// It uses doubled_t<T> for the intermediate calculations.
173+
// (That is, we give the floating point type T an upgrayedd, spelled with
174+
// two d's for a double dose of precision.)
175+
//
176+
// The function is used in log1p_complex() to avoid the loss of
177+
// precision that can occur in the expression when x**2 + y**2 ≈ -2*x.
178+
//
179+
template<typename T>
180+
inline T
181+
xsquared_plus_2x_plus_ysquared_dd(T x, T y)
182+
{
183+
doubled_t<T> x2, y2, twox, sum1, sum2;
184+
185+
square<T>(x, x2); // x2 = x**2
186+
square<T>(y, y2); // y2 = y**2
187+
twox.upper = 2*x; // twox = 2*x
188+
twox.lower = 0.0;
189+
double_sum<T>(x2, twox, sum1); // sum1 = x**2 + 2*x
190+
double_sum<T>(sum1, y2, sum2); // sum2 = x**2 + 2*x + y**2
191+
return sum2.upper;
192+
}
193+
194+
//
195+
// For the float type, the intermediate calculation is done
196+
// with the double type. We don't need to use doubled_t<float>.
197+
//
198+
inline float
199+
xsquared_plus_2x_plus_ysquared(float x, float y)
200+
{
201+
double xd = x;
202+
double yd = y;
203+
return xd*(2.0 + xd) + yd*yd;
204+
}
205+
206+
//
207+
// For double, we used doubled_t<double> if long double doesn't have
208+
// at least 106 bits of precision.
209+
//
210+
inline double
211+
xsquared_plus_2x_plus_ysquared(double x, double y)
212+
{
213+
if (std::numeric_limits<long double>::digits >= 106) {
214+
// Cast to long double for the calculation.
215+
long double xd = x;
216+
long double yd = y;
217+
return xd*(2.0L + xd) + yd*yd;
218+
}
219+
else {
220+
// Use doubled_t<double> for the calculation.
221+
return xsquared_plus_2x_plus_ysquared_dd<double>(x, y);
222+
}
223+
}
224+
225+
//
226+
// For long double, we always use doubled_t<long double> for the
227+
// calculation.
228+
//
229+
inline long double
230+
xsquared_plus_2x_plus_ysquared(long double x, long double y)
231+
{
232+
return xsquared_plus_2x_plus_ysquared_dd<long double>(x, y);
233+
}
234+
235+
//
236+
// Implement log1p(z) for complex inputs, using higher precision near
237+
// the unit circle |z + 1| = 1 to avoid loss of precision that can occur
238+
// when x**2 + y**2 ≈ -2*x.
239+
//
240+
// If the real or imaginary part of z is NAN, complex(NAN, NAN) is returned.
241+
//
242+
template<typename T>
243+
inline std::complex<T>
244+
log1p_complex(std::complex<T> z)
245+
{
246+
T lnr;
247+
T x = z.real();
248+
T y = z.imag();
249+
if (std::isnan(x) || std::isnan(y)) {
250+
return std::complex<T>(NPY_NAN, NPY_NAN);
251+
}
252+
if (x > -2.2 && x < 0.2 && y > -1.2 && y < 1.2
253+
&& npy::fabs(x*(2 + x) + y*y) < 0.4) {
254+
// The input is close to the unit circle centered at -1+0j.
255+
// Compute x**2 + 2*x + y**2 with higher precision than T.
256+
// The calculation here is equivalent to log(hypot(x+1, y)),
257+
// since
258+
// log(hypot(x+1, y)) = 0.5*log(x**2 + 2*x + 1 + y**2)
259+
// = 0.5*log1p(x**2 + 2*x + y**2)
260+
T t = xsquared_plus_2x_plus_ysquared(x, y);
261+
lnr = 0.5*npy::log1p(t);
262+
}
263+
else {
264+
lnr = npy::log(npy::hypot(x + static_cast<T>(1), y));
265+
}
266+
return std::complex<T>(lnr, npy::atan2(y, x + static_cast<T>(1)));
267+
}
268+
269+
} // namespace log1p_complex
270+
271+
#endif
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include <Python.h>
2+
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
3+
#include "numpy/npy_math.h"
4+
#include "numpy/ndarraytypes.h"
5+
#include <complex>
6+
#include "log1p_complex.h"
7+
8+
extern "C" {
9+
10+
/**begin repeat
11+
* #fptype = float, double, long double#
12+
* #ctype = npy_cfloat,npy_cdouble,npy_clongdouble#
13+
* #c = f, , l#
14+
*/
15+
16+
NPY_NO_EXPORT void
17+
nc_log1p@c@(@ctype@ *x, @ctype@ *r)
18+
{
19+
std::complex<@fptype@> z(npy_creal@c@(*x), npy_cimag@c@(*x));
20+
auto w = log1p_complex::log1p_complex(z);
21+
npy_csetreal@c@(r, w.real());
22+
npy_csetimag@c@(r, w.imag());
23+
}
24+
25+
/**end repeat**/
26+
27+
28+
} // extern "C"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef LOG1P_COMPLEX_WRAPPERS_H
2+
#define LOG1P_COMPLEX_WRAPPERS_H
3+
4+
// This header is to be included in umathmodule.c,
5+
// so it will be processed by a C compiler.
6+
//
7+
// This file assumes that the numpy header files have
8+
// already been included.
9+
10+
NPY_NO_EXPORT void
11+
nc_log1pf(npy_cfloat *x, npy_cfloat *r);
12+
13+
NPY_NO_EXPORT void
14+
nc_log1p(npy_cdouble *x, npy_cdouble *r);
15+
16+
NPY_NO_EXPORT void
17+
nc_log1pl(npy_clongdouble *x, npy_clongdouble *r);
18+
19+
#endif // LOG1P_COMPLEX_WRAPPERS_H

numpy/_core/src/umath/umathmodule.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "string_ufuncs.h"
3131
#include "stringdtype_ufuncs.h"
3232
#include "special_integer_comparisons.h"
33+
#include "log1p_complex_wrappers.h"
3334
#include "extobj.h" /* for _extobject_contextvar exposure */
3435
#include "ufunc_type_resolution.h"
3536

0 commit comments

Comments
 (0)
0