MAINT: random: Rewrite the hypergeometric distribution. · numpy/numpy@b2d2b67 · GitHub
[go: up one dir, main page]

Skip to content

Commit b2d2b67

Browse files
MAINT: random: Rewrite the hypergeometric distribution.
Summary of the changes: * Move the functions random_hypergeometric_hyp, random_hypergeometric_hrua and random_hypergeometric from distributions.c to legacy-distributions.c. These are now the legacy implementation of hypergeometric. * Add the files logfactorial.c and logfactorial.h, containing the function logfactorial(int64_t k). * Add the files random_hypergeometric.c and random_hypergeometric.h, containing the function random_hypergeometric (the new implementation of the hypergeometric distribution). See more details below. * Fix two tests in numpy/random/tests/test_generator_mt19937.py that used values returned by the hypergeometric distribution. The new implementation changes the stream, so those tests needed to be updated. * Remove another test obviated by an added constraint on the arguments of hypergeometric. Details of the rewrite: If you carefully step through the old function rk_hypergeometric_hyp(), you'll see that the end result is basically the same as the new function hypergeometric_sample(), but the new function accomplishes the result with just integers. The floating point calculations in the old code caused problems when the arguments were extremely large (explained in more detail in the unmerged pull request #9834). The new version of hypergeometric_hrua() is a new translation of Stadlober's ratio-of-uniforms algorithm for the hypergeometric distribution. It fixes a mistake in the old implementation that made the method less efficient than it could be (see the details in the unmerged pull request #11138), and uses a faster function for computing log(k!). The HRUA algorithm suffers from loss of floating point precision when the arguments are *extremely* large (see the comments in github issue 11443). To avoid these problems, the arguments `ngood` and `nbad` of hypergeometric must be less than 10**9. This constraint obviates an existing regression test that was run on systems with 64 bit long integers, so that test was removed.
1 parent e3eb398 commit b2d2b67

10 files changed

+592
-124
lines changed

numpy/random/generator.pyx

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3095,9 +3095,11 @@ cdef class Generator:
30953095
Parameters
30963096
----------
30973097
ngood : int or array_like of ints
3098-
Number of ways to make a good selection. Must be nonnegative.
3098+
Number of ways to make a good selection. Must be nonnegative and
3099+
less than 10**9.
30993100
nbad : int or array_like of ints
3100-
Number of ways to make a bad selection. Must be nonnegative.
3101+
Number of ways to make a bad selection. Must be nonnegative and
3102+
less than 10**9.
31013103
nsample : int or array_like of ints
31023104
Number of items sampled. Must be nonnegative and less than
31033105
``ngood + nbad``.
@@ -3142,6 +3144,13 @@ cdef class Generator:
31423144
replacement (or the sample space is infinite). As the sample space
31433145
becomes large, this distribution approaches the binomial.
31443146
3147+
The arguments `ngood` and `nbad` each must be less than `10**9`. For
3148+
extremely large arguments, the algorithm that is used to compute the
3149+
samples [4]_ breaks down because of loss of precision in floating point
3150+
calculations. For such large values, if `nsample` is not also large,
3151+
the distribution can be approximated with the binomial distribution,
3152+
`binomial(n=nsample, p=ngood/(ngood + nbad))`.
3153+
31453154
References
31463155
----------
31473156
.. [1] Lentner, Marvin, "Elementary Applied Statistics", Bogden
@@ -3151,6 +3160,9 @@ cdef class Generator:
31513160
http://mathworld.wolfram.com/HypergeometricDistribution.html
31523161
.. [3] Wikipedia, "Hypergeometric distribution",
31533162
https://en.wikipedia.org/wiki/Hypergeometric_distribution
3163+
.. [4] Stadlober, Ernst, "The ratio of uniforms approach for generating
3164+
discrete random variates", Journal of Computational and Applied
3165+
Mathematics, 31, pp. 181-189 (1990).
31543166
31553167
Examples
31563168
--------
@@ -3172,6 +3184,7 @@ cdef class Generator:
31723184
# answer = 0.003 ... pretty unlikely!
31733185
31743186
"""
3187+
DEF HYPERGEOM_MAX = 10**9
31753188
cdef bint is_scalar = True
31763189
cdef np.ndarray ongood, onbad, onsample
31773190
cdef int64_t lngood, lnbad, lnsample
@@ -3186,15 +3199,23 @@ cdef class Generator:
31863199
lnbad = <int64_t>nbad
31873200
lnsample = <int64_t>nsample
31883201

3202+
if lngood >= HYPERGEOM_MAX or lnbad >= HYPERGEOM_MAX:
3203+
raise ValueError("both ngood and nbad must be less than %d" %
3204+
HYPERGEOM_MAX)
31893205
if lngood + lnbad < lnsample:
31903206
raise ValueError("ngood + nbad < nsample")
31913207
return disc(&random_hypergeometric, &self._bitgen, size, self.lock, 0, 3,
31923208
lngood, 'ngood', CONS_NON_NEGATIVE,
31933209
lnbad, 'nbad', CONS_NON_NEGATIVE,
31943210
lnsample, 'nsample', CONS_NON_NEGATIVE)
31953211

3212+
if np.any(ongood >= HYPERGEOM_MAX) or np.any(onbad >= HYPERGEOM_MAX):
3213+
raise ValueError("both ngood and nbad must be less than %d" %
3214+
HYPERGEOM_MAX)
3215+
31963216
if np.any(np.less(np.add(ongood, onbad), onsample)):
31973217
raise ValueError("ngood + nbad < nsample")
3218+
31983219
return discrete_broadcast_iii(&random_hypergeometric, &self._bitgen, size, self.lock,
31993220
ongood, 'ngood', CONS_NON_NEGATIVE,
32003221
onbad, 'nbad', CONS_NON_NEGATIVE,

numpy/random/setup.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,15 @@ def generate_libraries(ext, build_dir):
126126
depends=['%s.pyx' % gen],
127127
define_macros=defs,
128128
)
129+
other_srcs = [
130+
'src/distributions/logfactorial.c',
131+
'src/distributions/distributions.c',
132+
'src/distributions/random_hypergeometric.c',
133+
]
129134
for gen in ['generator', 'bounded_integers']:
130135
# gen.pyx, src/distributions/distributions.c
131136
config.add_extension(gen,
132-
sources=['{0}.c'.format(gen),
133-
join('src', 'distributions',
134-
'distributions.c')],
137+
sources=['{0}.c'.format(gen)] + other_srcs,
135138
libraries=EXTRA_LIBRARIES,
136139
extra_compile_args=EXTRA_COMPILE_ARGS,
137140
include_dirs=['.', 'src'],
@@ -140,8 +143,10 @@ def generate_libraries(ext, build_dir):
140143
define_macros=defs,
141144
)
142145
config.add_extension('mtrand',
146+
# mtrand does not depend on random_hypergeometric.c.
143147
sources=['mtrand.c',
144148
'src/legacy/legacy-distributions.c',
149+
'src/distributions/logfactorial.c',
145150
'src/distributions/distributions.c'],
146151
include_dirs=['.', 'src', 'src/legacy'],
147152
libraries=EXTRA_LIBRARIES,

numpy/random/src/distributions/distributions.c

Lines changed: 5 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "distributions.h"
22
#include "ziggurat_constants.h"
3+
#include "logfactorial.h"
34

45
#if defined(_MSC_VER) && defined(_WIN64)
56
#include <intrin.h>
@@ -468,8 +469,11 @@ uint64_t random_uint(bitgen_t *bitgen_state) {
468469
* log-gamma function to support some of these distributions. The
469470
* algorithm comes from SPECFUN by Shanjie Zhang and Jianming Jin and their
470471
* book "Computation of Special Functions", 1996, John Wiley & Sons, Inc.
472+
*
473+
* If loggam(k+1) is being used to compute log(k!) for an integer k, consider
474+
* using logfactorial(k) instead.
471475
*/
472-
static double loggam(double x) {
476+
double loggam(double x) {
473477
double x0, x2, xp, gl, gl0;
474478
RAND_INT_TYPE k, n;
475479

@@ -1127,105 +1131,6 @@ double random_triangular(bitgen_t *bitgen_state, double left, double mode,
11271131
}
11281132
}
11291133

1130-
RAND_INT_TYPE random_hypergeometric_hyp(bitgen_t *bitgen_state,
1131-
RAND_INT_TYPE good, RAND_INT_TYPE bad,
1132-
RAND_INT_TYPE sample) {
1133-
RAND_INT_TYPE d1, k, z;
1134-
double d2, u, y;
1135-
1136-
d1 = bad + good - sample;
1137-
d2 = (double)MIN(bad, good);
1138-
1139-
y = d2;
1140-
k = sample;
1141-
while (y > 0.0) {
1142-
u = next_double(bitgen_state);
1143-
y -= (RAND_INT_TYPE)floor(u + y / (d1 + k));
1144-
k--;
1145-
if (k == 0)
1146-
break;
1147-
}
1148-
z = (RAND_INT_TYPE)(d2 - y);
1149-
if (good > bad)
1150-
z = sample - z;
1151-
return z;
1152-
}
1153-
1154-
/* D1 = 2*sqrt(2/e) */
1155-
/* D2 = 3 - 2*sqrt(3/e) */
1156-
#define D1 1.7155277699214135
1157-
#define D2 0.8989161620588988
1158-
RAND_INT_TYPE random_hypergeometric_hrua(bitgen_t *bitgen_state,
1159-
RAND_INT_TYPE good, RAND_INT_TYPE bad,
1160-
RAND_INT_TYPE sample) {
1161-
RAND_INT_TYPE mingoodbad, maxgoodbad, popsize, m, d9;
1162-
double d4, d5, d6, d7, d8, d10, d11;
1163-
RAND_INT_TYPE Z;
1164-
double T, W, X, Y;
1165-
1166-
mingoodbad = MIN(good, bad);
1167-
popsize = good + bad;
1168-
maxgoodbad = MAX(good, bad);
1169-
m = MIN(sample, popsize - sample);
1170-
d4 = ((double)mingoodbad) / popsize;
1171-
d5 = 1.0 - d4;
1172-
d6 = m * d4 + 0.5;
1173-
d7 = sqrt((double)(popsize - m) * sample * d4 * d5 / (popsize - 1) + 0.5);
1174-
d8 = D1 * d7 + D2;
1175-
d9 = (RAND_INT_TYPE)floor((double)(m + 1) * (mingoodbad + 1) / (popsize + 2));
1176-
d10 = (loggam(d9 + 1) + loggam(mingoodbad - d9 + 1) + loggam(m - d9 + 1) +
1177-
loggam(maxgoodbad - m + d9 + 1));
1178-
d11 = MIN(MIN(m, mingoodbad) + 1.0, floor(d6 + 16 * d7));
1179-
/* 16 for 16-decimal-digit precision in D1 and D2 */
1180-
1181-
while (1) {
1182-
X = next_double(bitgen_state);
1183-
Y = next_double(bitgen_state);
1184-
W = d6 + d8 * (Y - 0.5) / X;
1185-
1186-
/* fast rejection: */
1187-
if ((W < 0.0) || (W >= d11))
1188-
continue;
1189-
1190-
Z = (RAND_INT_TYPE)floor(W);
1191-
T = d10 - (loggam(Z + 1) + loggam(mingoodbad - Z + 1) + loggam(m - Z + 1) +
1192-
loggam(maxgoodbad - m + Z + 1));
1193-
1194-
/* fast acceptance: */
1195-
if ((X * (4.0 - X) - 3.0) <= T)
1196-
break;
1197-
1198-
/* fast rejection: */
1199-
if (X * (X - T) >= 1)
1200-
continue;
1201-
/* log(0.0) is ok here, since always accept */
1202-
if (2.0 * log(X) <= T)
1203-
break; /* acceptance */
1204-
}
1205-
1206-
/* this is a correction to HRUA* by Ivan Frohne in rv.py */
1207-
if (good > bad)
1208-
Z = m - Z;
1209-
1210-
/* another fix from rv.py to allow sample to exceed popsize/2 */
1211-
if (m < sample)
1212-
Z = good - Z;
1213-
1214-
return Z;
1215-
}
1216-
#undef D1
1217-
#undef D2
1218-
1219-
RAND_INT_TYPE random_hypergeometric(bitgen_t *bitgen_state, RAND_INT_TYPE good,
1220-
RAND_INT_TYPE bad, RAND_INT_TYPE sample) {
1221-
if (sample > 10) {
1222-
return random_hypergeometric_hrua(bitgen_state, good, bad, sample);
1223-
} else if (sample > 0) {
1224-
return random_hypergeometric_hyp(bitgen_state, good, bad, sample);
1225-
} else {
1226-
return 0;
1227-
}
1228-
}
12291134

12301135
uint64_t random_interval(bitgen_t *bitgen_state, uint64_t max) {
12311136
uint64_t mask, value;

numpy/random/src/distributions/distributions.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ static NPY_INLINE double next_double(bitgen_t *bitgen_state) {
8484
return bitgen_state->next_double(bitgen_state->state);
8585
}
8686

87+
DECLDIR double loggam(double x);
88+
8789
DECLDIR float random_float(bitgen_t *bitgen_state);
8890
DECLDIR double random_double(bitgen_t *bitgen_state);
8991
DECLDIR void random_double_fill(bitgen_t *bitgen_state, npy_intp cnt, double *out);
@@ -160,8 +162,8 @@ DECLDIR RAND_INT_TYPE random_geometric_search(bitgen_t *bitgen_state, double p);
160162
DECLDIR RAND_INT_TYPE random_geometric_inversion(bitgen_t *bitgen_state, double p);
161163
DECLDIR RAND_INT_TYPE random_geometric(bitgen_t *bitgen_state, double p);
162164
DECLDIR RAND_INT_TYPE random_zipf(bitgen_t *bitgen_state, double a);
163-
DECLDIR RAND_INT_TYPE random_hypergeometric(bitgen_t *bitgen_state, RAND_INT_TYPE good,
164-
RAND_INT_TYPE bad, RAND_INT_TYPE sample);
165+
DECLDIR int64_t random_hypergeometric(bitgen_t *bitgen_state,
166+
int64_t good, int64_t bad, int64_t sample);
165167

166168
DECLDIR uint64_t random_interval(bitgen_t *bitgen_state, uint64_t max);
167169

0 commit comments

Comments
 (0)
0