BUG: Make floating remainder ufunc more exact. · charris/numpy@cb191fe · GitHub
[go: up one dir, main page]

Skip to content

Commit cb191fe

Browse files
committed
BUG: Make floating remainder ufunc more exact.
The intent here is to make sure the following is true for floating numbers x (dividend) and y (divisor) and r = remainder(x, y).
* If both x and y are small integer floats, r is exact.
* The sign of r is the same as the sign of y, including signed zero.
* The magnitude of r is strictly less than the magnitude of y.
* x ~= r + y*floor(x/y), i.e., r is compatible with floor.
Remainder functions are also added to npy_math for all supported floats 'efdg'. This helps keep scalar and array results in sync. Explicit loops are also made for remainder as the speedup over using generic loops is about 20%. Note that the NumPy version of remainder differs from that in Python, as the latter is based around the fmod function rather than floor. Closes numpy#7224.
1 parent 920c527 commit cb191fe

File tree

7 files changed

+89
-14
lines changed

7 files changed

+89
-14
lines changed

numpy/core/include/numpy/halffloat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ int npy_half_signbit(npy_half h);
3737
npy_half npy_half_copysign(npy_half x, npy_half y);
3838
npy_half npy_half_spacing(npy_half h);
3939
npy_half npy_half_nextafter(npy_half x, npy_half y);
40+
npy_half npy_half_remainder(npy_half x, npy_half y);
4041

4142
/*
4243
* Half-precision constants

numpy/core/include/numpy/npy_math.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,16 +309,19 @@ double npy_deg2rad(double x);
309309
double npy_rad2deg(double x);
310310
double npy_logaddexp(double x, double y);
311311
double npy_logaddexp2(double x, double y);
312+
double npy_remainder(double x, double y);
312313

313314
float npy_deg2radf(float x);
314315
float npy_rad2degf(float x);
315316
float npy_logaddexpf(float x, float y);
316317
float npy_logaddexp2f(float x, float y);
318+
float npy_remainderf(float x, float y);
317319

318320
npy_longdouble npy_deg2radl(npy_longdouble x);
319321
npy_longdouble npy_rad2degl(npy_longdouble x);
320322
npy_longdouble npy_logaddexpl(npy_longdouble x, npy_longdouble y);
321323
npy_longdouble npy_logaddexp2l(npy_longdouble x, npy_longdouble y);
324+
npy_longdouble npy_remainderl(npy_longdouble x, npy_longdouble y);
322325

323326
#define npy_degrees npy_rad2deg
324327
#define npy_degreesf npy_rad2degf

numpy/core/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ def generate_umath_c(ext, build_dir):
896896

897897
umath_deps = [
898898
generate_umath_py,
899+
join('include', 'numpy', 'npy_math.h'),
899900
join('src', 'multiarray', 'common.h'),
900901
join('src', 'private', 'templ_common.h.src'),
901902
join('src', 'umath', 'simd.inc.src'),

numpy/core/src/npymath/halffloat.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,25 @@ int npy_half_signbit(npy_half h)
7272
return (h&0x8000u) != 0;
7373
}
7474

75+
npy_half npy_half_remainder(npy_half x, npy_half y)
76+
{
77+
const npy_half half_zero = (npy_half)0;
78+
const float xf = npy_half_to_float(x);
79+
const float yf = npy_half_to_float(y);
80+
float remf;
81+
npy_half remh;
82+
83+
remh = npy_float_to_half(npy_remainderf(xf, yf));
84+
remf = npy_half_to_float(remh);
85+
if (yf > 0 && remf >= yf) {
86+
remh = npy_half_nextafter(remh, half_zero);
87+
}
88+
if (yf < 0 && remf <= yf) {
89+
remh = npy_half_nextafter(remh, half_zero);
90+
}
91+
return remh;
92+
}
93+
7594
npy_half npy_half_spacing(npy_half h)
7695
{
7796
npy_half ret;

numpy/core/src/npymath/npy_math.c.src

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,36 @@ double npy_log2(double x)
608608
}
609609
}
610610

611+
612+
/* remainder(x, y)
613+
*
614+
* Unlike Python, we assume that the floor function is sacred rather
615+
* than fmod. The result is guaranteed to have the same sign as the
616+
* divisor and abs(remainder) < abs(y).
617+
*/
618+
@type@ npy_remainder@c@(@type@ x, @type@ y)
619+
{
620+
@type@ rem = x - y*npy_floor@c@(x/y);
621+
622+
if (y < 0) {
623+
if (rem >= 0) {
624+
rem = -0.0@c@;
625+
}
626+
else if (rem <= y) {
627+
rem = npy_nextafter@c@(y, 0);
628+
}
629+
}
630+
else if (y > 0) {
631+
if (rem <= 0) {
632+
rem = 0.0@c@;
633+
}
634+
else if (rem >= y) {
635+
rem = npy_nextafter@c@(y, 0);
636+
}
637+
}
638+
return rem;
639+
}
640+
611641
#undef LOGE2
612642
#undef LOG2E
613643
#undef RAD2DEG

numpy/core/src/umath/loops.c.src

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,8 +1706,26 @@ NPY_NO_EXPORT void
17061706
BINARY_LOOP {
17071707
const @type@ in1 = *(@type@ *)ip1;
17081708
const @type@ in2 = *(@type@ *)ip2;
1709-
const @type@ div = in1/in2;
1710-
*((@type@ *)op1) = in2*(div - npy_floor@c@(div));
1709+
@type@ rem;
1710+
1711+
rem = in1 - in2*npy_floor@c@(in1/in2);
1712+
if (in2 < 0) {
1713+
if (rem >= 0) {
1714+
rem = -0.0@c@;
1715+
}
1716+
else if (rem <= in2) {
1717+
rem = npy_nextafter@c@(in2, 0);
1718+
}
1719+
}
1720+
else if (in2 > 0) {
1721+
if (rem <= 0) {
1722+
rem = 0.0@c@;
1723+
}
1724+
else if (rem >= in2) {
1725+
rem = npy_nextafter@c@(in2, 0);
1726+
}
1727+
}
1728+
*((@type@ *)op1) = rem;
17111729
}
17121730
}
17131731

@@ -2023,13 +2041,18 @@ HALF_remainder(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNU
20232041
BINARY_LOOP {
20242042
const float in1 = npy_half_to_float(*(npy_half *)ip1);
20252043
const float in2 = npy_half_to_float(*(npy_half *)ip2);
2026-
const float res = npy_fmodf(in1,in2);
2027-
if (res && ((in2 < 0) != (res < 0))) {
2028-
*((npy_half *)op1) = npy_float_to_half(res + in2);
2044+
float remf;
2045+
npy_half remh;
2046+
2047+
remh = npy_float_to_half(npy_remainderf(in1, in2));
2048+
remf = npy_half_to_float(remh);
2049+
if (in2 > 0 && remf >= in2) {
2050+
remh = npy_half_nextafter(remh, NPY_HALF_ZERO);
20292051
}
2030-
else {
2031-
*((npy_half *)op1) = npy_float_to_half(res);
2052+
if (in2 < 0 && remf <= in2) {
2053+
remh = npy_half_nextafter(remh, NPY_HALF_ZERO);
20322054
}
2055+
*((npy_half *)op1) = remh;
20332056
}
20342057
}
20352058

numpy/core/src/umath/scalarmath.c.src

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "npy_pycompat.h"
2222

2323
#include "numpy/halffloat.h"
24+
#include "numpy/npy_math.h"
2425
#include "templ_common.h"
2526

2627
/* Basic operations:
@@ -283,7 +284,6 @@ static @type@ (*_basic_@name@_fmod)(@type@, @type@);
283284

284285
static npy_half (*_basic_half_floor)(npy_half);
285286
static npy_half (*_basic_half_sqrt)(npy_half);
286-
static npy_half (*_basic_half_fmod)(npy_half, npy_half);
287287

288288
#define half_ctype_add(a, b, outp) *(outp) = \
289289
npy_float_to_half(npy_half_to_float(a) + npy_half_to_float(b))
@@ -353,22 +353,21 @@ static npy_half (*_basic_half_fmod)(npy_half, npy_half);
353353
} while(0)
354354
/**end repeat**/
355355

356+
356357
/**begin repeat
357358
* #name = float, double, longdouble#
358359
* #type = npy_float, npy_double, npy_longdouble#
360+
* #c = f, ,l#
359361
*/
360362
static void
361363
@name@_ctype_remainder(@type@ a, @type@ b, @type@ *out) {
362-
@type@ tmp = a/b;
363-
*out = b * (tmp - _basic_@name@_floor(tmp));
364+
*out = npy_remainder@c@(a, b);
364365
}
365366
/**end repeat**/
366367

367368
static void
368369
half_ctype_remainder(npy_half a, npy_half b, npy_half *out) {
369-
float tmp, fa = npy_half_to_float(a), fb = npy_half_to_float(b);
370-
float_ctype_remainder(fa, fb, &tmp);
371-
*out = npy_float_to_half(tmp);
370+
*out = npy_half_remainder(a, b);
372371
}
373372

374373

@@ -1721,7 +1720,6 @@ get_functions(PyObject * mm)
17211720
i += 3;
17221721
j++;
17231722
}
1724-
_basic_half_fmod = funcdata[j - 1];
17251723
_basic_float_fmod = funcdata[j];
17261724
_basic_double_fmod = funcdata[j + 1];
17271725
_basic_longdouble_fmod = funcdata[j + 2];

0 commit comments

Comments
 (0)
0