8000 Merge pull request #13794 from WarrenWeckesser/new-mvhg · numpy/numpy@2aa3bba · GitHub
[go: up one dir, main page]

Skip to content

Commit 2aa3bba

Browse files
authored
Merge pull request #13794 from WarrenWeckesser/new-mvhg
ENH: random: Add the multivariate hypergeometric distribution.
2 parents 9c8e904 + 0455447 commit 2aa3bba

File tree

8 files changed

+674
-3
lines changed

8 files changed

+674
-3
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Multivariate hypergeometric distribution added to `numpy.random`
2+
----------------------------------------------------------------
3+
The method `multivariate_hypergeometric` has been added to the class
4+
`numpy.random.Generator`. This method generates random variates from
5+
the multivariate hypergeometric probability distribution.

doc/source/reference/random/generator.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ Distributions
6262
~numpy.random.Generator.lognormal
6363
~numpy.random.Generator.logseries
6464
~numpy.random.Generator.multinomial
65+
~numpy.random.Generator.multivariate_hypergeometric
6566
~numpy.random.Generator.multivariate_normal
6667
~numpy.random.Generator.negative_binomial
6768
~numpy.random.Generator.noncentral_chisquare

numpy/random/_generator.pyx

Lines changed: 248 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ from numpy.core.multiarray import normalize_axis_index
1313

1414
from libc cimport string
1515
from libc.stdint cimport (uint8_t, uint16_t, uint32_t, uint64_t,
16-
int32_t, int64_t)
16+
int32_t, int64_t, INT64_MAX, SIZE_MAX)
1717
from ._bounded_integers cimport (_rand_bool, _rand_int32, _rand_int64,
1818
_rand_int16, _rand_int8, _rand_uint64, _rand_uint32, _rand_uint16,
1919
_rand_uint8, _gen_mask)
@@ -126,9 +126,38 @@ cdef extern from "include/distributions.h":
126126
void random_multinomial(bitgen_t *bitgen_state, int64_t n, int64_t *mnix,
127127
double *pix, np.npy_intp d, binomial_t *binomial) nogil
128128

129+
int random_mvhg_count(bitgen_t *bitgen_state,
130+
int64_t total,
131+
size_t num_colors, int64_t *colors,
132+
int64_t nsample,
133+
size_t num_variates, int64_t *variates) nogil
134+
void random_mvhg_marginals(bitgen_t *bitgen_state,
135+
int64_t total,
136+
size_t num_colors, int64_t *colors,
137+
int64_t nsample,
138+
size_t num_variates, int64_t *variates) nogil
139+
129140
np.import_array()
130141

131142

143+
cdef int64_t _safe_sum_nonneg_int64(size_t num_colors, int64_t *colors):
144+
"""
145+
B41A Sum the values in the array `colors`.
146+
147+
Return -1 if an overflow occurs.
148+
The values in *colors are assumed to be nonnegative.
149+
"""
150+
cdef size_t i
151+
cdef int64_t sum
152+
153+
sum = 0
154+
for i in range(num_colors):
155+
if colors[i] > INT64_MAX - sum:
156+
return -1
157+
sum += colors[i]
158+
return sum
159+
160+
132161
cdef bint _check_bit_generator(object bitgen):
133162
"""Check if an object satisfies the BitGenerator interface.
134163
"""
@@ -3241,6 +3270,8 @@ cdef class Generator:
32413270
32423271
See Also
32433272
--------
3273+
multivariate_hypergeometric : Draw samples from the multivariate
3274+
hypergeometric distribution.
32443275
scipy.stats.hypergeom : probability density function, distribution or
32453276
cumulative density function, etc.
32463277
@@ -3739,6 +3770,222 @@ cdef class Generator:
37393770

37403771
return multin
37413772

3773+
def multivariate_hypergeometric(self, object colors, object nsample,
3774+
size=None, method='marginals'):
3775+
"""
3776+
multivariate_hypergeometric(colors, nsample, size=None,
3777+
method='marginals')
3778+
3779+
Generate variates from a multivariate hypergeometric distribution.
3780+
3781+
The multivariate hypergeometric distribution is a generalization
3782+
of the hypergeometric distribution.
3783+
3784+
Choose ``nsample`` items at random without replacement from a
3785+
collection with ``N`` distinct types. ``N`` is the length of
3786+
``colors``, and the values in ``colors`` are the number of occurrences
3787+
of that type in the collection. The total number of items in the
3788+
collection is ``sum(colors)``. Each random variate generated by this
3789+
function is a vector of length ``N`` holding the counts of the
3790+
different types that occurred in the ``nsample`` items.
3791+
3792+
The name ``colors`` comes from a common description of the
3793+
distribution: it is the probability distribution of the number of
3794+
marbles of each color selected without replacement from an urn
3795+
containing marbles of different colors; ``colors[i]`` is the number
3796+
of marbles in the urn with color ``i``.
3797+
3798+
Parameters
3799+
----------
3800+
colors : sequence of integers
3801+
The number of each type of item in the collection from which
3802+
a sample is drawn. The values in ``colors`` must be nonnegative.
3803+
To avoid loss of precision in the algorithm, ``sum(colors)``
3804+
must be less than ``10**9`` when `method` is "marginals".
3805+
nsample : int
3806+
The number of items selected. ``nsample`` must not be greater
3807+
than ``sum(colors)``.
3808+
size : int or tuple of ints, optional
3809+
The number of variates to generate, either an integer or a tuple
3810+
holding the shape of the array of variates. If the given size is,
3811+
e.g., ``(k, m)``, then ``k * m`` variates are drawn, where one
3812+
variate is a vector of length ``len(colors)``, and the return value
3813+
has shape ``(k, m, len(colors))``. If `size` is an integer, the
3814+
output has shape ``(size, len(colors))``. Default is None, in
< 10000 /td>3815+
which case a single variate is returned as an array with shape
3816+
``(len(colors),)``.
3817+
method : string, optional
3818+
Specify the algorithm that is used to generate the variates.
3819+
Must be 'count' or 'marginals' (the default). See the Notes
3820+
for a description of the methods.
3821+
3822+
Returns
3823+
-------
3824+
variates : ndarray
3825+
Array of variates drawn from the multivariate hypergeometric
3826+
distribution.
3827+
3828+
See Also
3829+
--------
3830+
hypergeometric : Draw samples from the (univariate) hypergeometric
3831+
distribution.
3832+
3833+
Notes
3834+
-----
3835+
The two methods do not return the same sequence of variates.
3836+
3837+
The "count" algorithm is roughly equivalent to the following numpy
3838+
code::
3839+
3840+
choices = np.repeat(np.arange(len(colors)), colors)
3841+
selection = np.random.choice(choices, nsample, replace=False)
3842+
variate = np.bincount(selection, minlength=len(colors))
3843+
3844+
The "count" algorithm uses a temporary array of integers with length
3845+
``sum(colors)``.
3846+
3847+
The "marginals" algorithm generates a variate by using repeated
3848+
calls to the univariate hypergeometric sampler. It is roughly
3849+
equivalent to::
3850+
3851+
variate = np.zeros(len(colors), dtype=np.int64)
3852+
# `remaining` is the cumulative sum of `colors` from the last
3853+
# element to the first; e.g. if `colors` is [3, 1, 5], then
3854+
# `remaining` is [9, 6, 5].
3855+
remaining = np.cumsum(colors[::-1])[::-1]
3856+
for i in range(len(colors)-1):
3857+
if nsample < 1:
3858+
break
3859+
variate[i] = hypergeometric(colors[i], remaining[i+1],
3860+
nsample)
3861+
nsample -= variate[i]
3862+
variate[-1] = nsample
3863+
3864+
The default method is "marginals". For some cases (e.g. when
3865+
`colors` contains relatively small integers), the "count" method
3866+
can be significantly faster than the "marginals" method. If
3867+
performance of the algorithm is important, test the two methods
3868+
with typical inputs to decide which works best.
3869+
3870+
.. versionadded:: 1.18.0
3871+
3872+
Examples
3873+
--------
3874+
>>> colors = [16, 8, 4]
3875+
>>> seed = 4861946401452
3876+
>>> gen = np.random.Generator(np.random.PCG64(seed))
3877+
>>> gen.multivariate_hypergeometric(colors, 6)
3878+
array([5, 0, 1])
3879+
>>> gen.multivariate_hypergeometric(colors, 6, size=3)
3880+
array([[5, 0, 1],
3881+
[2, 2, 2],
3882+
[3, 3, 0]])
3883+
>>> gen.multivariate_hypergeometric(colors, 6, size=(2, 2))
3884+
array([[[3, 2, 1],
3885+
[3, 2, 1]],
3886+
[[4, 1, 1],
3887+
[3, 2, 1]]])
3888+
"""
3889+
cdef int64_t nsamp
3890+
cdef size_t num_colors
3891+
cdef int64_t total
3892+
cdef int64_t *colors_ptr
3893+
cdef int64_t max_index
3894+
cdef size_t num_variates
3895+
cdef int64_t *variates_ptr
3896+
cdef int result
3897+
3898+
if method not in ['count', 'marginals']:
3899+
raise ValueError('method must be "count" or "marginals".')
3900+
3901+
try:
3902+
operator.index(nsample)
3903+
except TypeError:
3904+
raise ValueError('nsample must be an integer')
3905+
3906+
if nsample < 0:
3907+
raise ValueError("nsample must be nonnegative.")
3908+
if nsample > INT64_MAX:
3909+
raise ValueError("nsample must not exceed %d" % INT64_MAX)
3910+
nsamp = nsample
3911+
3912+
# Validation of colors, a 1-d sequence of nonnegative integers.
3913+
invalid_colors = False
3914+
try:
3915+
colors = np.asarray(colors)
3916+
if colors.ndim != 1:
3917+
invalid_colors = True
3918+
elif colors.size > 0 and not np.issubdtype(colors.dtype,
3919+
np.integer):
3920+
invalid_colors = True
3921+
elif np.any((colors < 0) | (colors > INT64_MAX)):
3922+
invalid_colors = True
3923+
except ValueError:
3924+
invalid_colors = True
3925+
if invalid_colors:
3926+
raise ValueError('colors must be a one-dimensional sequence '
3927+
'of nonnegative integers not exceeding %d.' %
3928+
INT64_MAX)
3929+
3930+
colors = np.ascontiguousarray(colors, dtype=np.int64)
3931+
num_colors = colors.size
3932+
3933+
colors_ptr = <int64_t *> np.PyArray_DATA(colors)
3934+
3935+
total = _safe_sum_nonneg_int64(num_colors, colors_ptr)
3936+
if total == -1:
3937+
raise ValueError("sum(colors) must not exceed the maximum value "
3938+
"of a 64 bit signed integer (%d)" % INT64_MAX)
3939+
3940+
if method == 'marginals' and total >= 1000000000:
3941+
raise ValueError('When method is "marginals", sum(colors) must '
3942+
'be less than 1000000000.')
3943+
3944+
# The C code that implements the 'count' method will malloc an
3945+
# array of size total*sizeof(size_t). Here we ensure that that
3946+
# product does not overflow.
3947+
if SIZE_MAX > <uint64_t>INT64_MAX:
3948+
max_index = INT64_MAX // sizeof(size_t)
3949+
else:
3950+
max_index = SIZE_MAX // sizeof(size_t)
3951+
if method == 'count' and total > max_index:
3952+
raise ValueError("When method is 'count', sum(colors) must not "
3953+
"exceed %d" % max_index)
3954+
if nsamp > total:
3955+
raise ValueError("nsample > sum(colors)")
3956+
3957+
# Figure out the shape of the return array.
3958+
if size is None:
3959+
shape = (num_colors,)
3960+
elif np.isscalar(size):
3961+
shape = (size, num_colors)
3962+
else:
3963+
shape = tuple(size) + (num_colors,)
3964+
variates = np.zeros(shape, dtype=np.int64)
3965+
3966+
if num_colors == 0:
3967+
return variates
3968+
3969+
# One variate is a vector of length num_colors.
3970+
num_variates = variates.size // num_colors
3971+
variates_ptr = <int64_t *> np.PyArray_DATA(variates)
3972+
3973+
if method == 'count':
3974+
with self.lock, nogil:
3975+
result = random_mvhg_count(&self._bitgen, total,
3976+
num_colors, colors_ptr, nsamp,
3977+
num_variates, variates_ptr)
3978+
if result == -1:
3979+
raise MemoryError("Insufficent memory for multivariate_"
3980+
"hypergeometric with method='count' and "
3981+
"sum(colors)=%d" % total)
3982+
else:
3983+
with self.lock, nogil:
3984+
random_mvhg_marginals(&self._bitgen, total,
3985+
num_colors, colors_ptr, nsamp,
3986+
num_variates, variates_ptr)
3987+
return variates
3988+
37423989
def dirichlet(self, object alpha, size=None):
37433990
"""
37443991
dirichlet(alpha, size=None)

numpy/random/include/distributions.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,20 @@ DECLDIR void random_bounded_bool_fill(bitgen_t *bitgen_state, npy_bool off,
170170
DECLDIR void random_multinomial(bitgen_t *bitgen_state, RAND_INT_TYPE n, RAND_INT_TYPE *mnix,
171171
double *pix, npy_intp d, binomial_t *binomial);
172172

173+
/* multivariate hypergeometric, "count" method */
174+
DECLDIR int random_mvhg_count(bitgen_t *bitgen_state,
175+
int64_t total,
176+
size_t num_colors, int64_t *colors,
177+
int64_t nsample,
178+
size_t num_variates, int64_t *variates);
179+
180+
/* multivariate hypergeometric, "marginals" method */
181+
DECLDIR void random_mvhg_marginals(bitgen_t *bitgen_state,
182+
int64_t total,
183+
size_t num_colors, int64_t *colors,
184+
int64_t nsample,
185+
size_t num_variates, int64_t *variates);
186+
173187
/* Common to legacy-distributions.c and distributions.c but not exported */
174188

175189
RAND_INT_TYPE random_binomial_btpe(bitgen_t *bitgen_state,

numpy/random/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def generate_libraries(ext, build_dir):
100100
other_srcs = [
101101
'src/distributions/logfactorial.c',
102102
'src/distributions/distributions.c',
103+
'src/distributions/random_mvhg_count.c',
104+
'src/distributions/random_mvhg_marginals.c',
103105
'src/distributions/random_hypergeometric.c',
104106
]
105107
for gen in ['_generator', '_bounded_integers']:
@@ -114,7 +116,6 @@ def generate_libraries(ext, build_dir):
114116
define_macros=defs,
115117
)
116118
config.add_extension('mtrand',
117-
# mtrand does not depend on random_hypergeometric.c.
118119
sources=['mtrand.c',
119120
'src/legacy/legacy-distributions.c',
120121
'src/distributions/logfactorial.c',

0 commit comments

Comments
 (0)
0