@@ -13,7 +13,7 @@ from numpy.core.multiarray import normalize_axis_index
13
13
14
14
from libc cimport string
15
15
from libc .stdint cimport (uint8_t , uint16_t , uint32_t , uint64_t ,
16
- int32_t , int64_t )
16
+ int32_t , int64_t , INT64_MAX , SIZE_MAX )
17
17
from ._bounded_integers cimport (_rand_bool , _rand_int32 , _rand_int64 ,
18
18
_rand_int16 , _rand_int8 , _rand_uint64 , _rand_uint32 , _rand_uint16 ,
19
19
_rand_uint8 , _gen_mask )
@@ -126,9 +126,38 @@ cdef extern from "include/distributions.h":
126
126
void random_multinomial (bitgen_t * bitgen_state , int64_t n , int64_t * mnix ,
127
127
double * pix , np .npy_intp d , binomial_t * binomial ) nogil
128
128
129
+ int random_mvhg_count (bitgen_t * bitgen_state ,
130
+ int64_t total ,
131
+ size_t num_colors , int64_t * colors ,
132
+ int64_t nsample ,
133
+ size_t num_variates , int64_t * variates ) nogil
134
+ void random_mvhg_marginals (bitgen_t * bitgen_state ,
135
+ int64_t total ,
136
+ size_t num_colors , int64_t * colors ,
137
+ int64_t nsample ,
138
+ size_t num_variates , int64_t * variates ) nogil
139
+
129
140
np .import_array ()
130
141
131
142
143
cdef int64_t _safe_sum_nonneg_int64(size_t num_colors, int64_t *colors):
    """
    Add up the first `num_colors` entries of the C array `colors`.

    The entries are assumed to be nonnegative; that is not checked here.
    If adding an entry would push the running total past INT64_MAX,
    -1 is returned instead of the sum.
    """
    cdef size_t k
    cdef int64_t total = 0
    cdef int64_t value

    for k in range(num_colors):
        value = colors[k]
        # Same condition as `total + value > INT64_MAX`, rearranged so
        # that the comparison itself cannot overflow.
        if INT64_MAX - total < value:
            return -1
        total = total + value
    return total
159
+
160
+
132
161
cdef bint _check_bit_generator (object bitgen ):
133
162
"""Check if an object satisfies the BitGenerator interface.
134
163
"""
@@ -3241,6 +3270,8 @@ cdef class Generator:
3241
3270
3242
3271
See Also
3243
3272
--------
3273
+ multivariate_hypergeometric : Draw samples from the multivariate
3274
+ hypergeometric distribution.
3244
3275
scipy.stats.hypergeom : probability density function, distribution or
3245
3276
cumulative density function, etc.
3246
3277
@@ -3739,6 +3770,222 @@ cdef class Generator:
3739
3770
3740
3771
return multin
3741
3772
3773
def multivariate_hypergeometric(self, object colors, object nsample,
                                size=None, method='marginals'):
    """
    multivariate_hypergeometric(colors, nsample, size=None,
                                method='marginals')

    Generate variates from a multivariate hypergeometric distribution.

    The multivariate hypergeometric distribution is a generalization
    of the hypergeometric distribution.

    Choose ``nsample`` items at random without replacement from a
    collection with ``N`` distinct types.  ``N`` is the length of
    ``colors``, and the values in ``colors`` are the number of occurrences
    of that type in the collection.  The total number of items in the
    collection is ``sum(colors)``.  Each random variate generated by this
    function is a vector of length ``N`` holding the counts of the
    different types that occurred in the ``nsample`` items.

    The name ``colors`` comes from a common description of the
    distribution: it is the probability distribution of the number of
    marbles of each color selected without replacement from an urn
    containing marbles of different colors; ``colors[i]`` is the number
    of marbles in the urn with color ``i``.

    Parameters
    ----------
    colors : sequence of integers
        The number of each type of item in the collection from which
        a sample is drawn.  The values in ``colors`` must be nonnegative.
        To avoid loss of precision in the algorithm, ``sum(colors)``
        must be less than ``10**9`` when `method` is "marginals".
    nsample : int
        The number of items selected.  ``nsample`` must not be greater
        than ``sum(colors)``.
    size : int or tuple of ints, optional
        The number of variates to generate, either an integer or a tuple
        holding the shape of the array of variates.  If the given size is,
        e.g., ``(k, m)``, then ``k * m`` variates are drawn, where one
        variate is a vector of length ``len(colors)``, and the return value
        has shape ``(k, m, len(colors))``.  If `size` is an integer, the
        output has shape ``(size, len(colors))``.  Default is None, in
        which case a single variate is returned as an array with shape
        ``(len(colors),)``.
    method : string, optional
        Specify the algorithm that is used to generate the variates.
        Must be 'count' or 'marginals' (the default).  See the Notes
        for a description of the methods.

    Returns
    -------
    variates : ndarray
        Array of variates drawn from the multivariate hypergeometric
        distribution.

    Raises
    ------
    ValueError
        If `method` is not 'count' or 'marginals'; if `nsample` is not
        an integer, is negative, exceeds the maximum 64 bit signed
        integer, or is greater than ``sum(colors)``; if `colors` is not
        a one-dimensional sequence of nonnegative integers; or if
        ``sum(colors)`` overflows a 64 bit signed integer or exceeds
        the limit of the selected method.
    MemoryError
        If `method` is 'count' and the underlying sampler reports that
        it failed to obtain its working memory.

    See Also
    --------
    hypergeometric : Draw samples from the (univariate) hypergeometric
        distribution.

    Notes
    -----
    The two methods do not return the same sequence of variates.

    The "count" algorithm is roughly equivalent to the following numpy
    code::

        choices = np.repeat(np.arange(len(colors)), colors)
        selection = np.random.choice(choices, nsample, replace=False)
        variate = np.bincount(selection, minlength=len(colors))

    The "count" algorithm uses a temporary array of integers with length
    ``sum(colors)``.

    The "marginals" algorithm generates a variate by using repeated
    calls to the univariate hypergeometric sampler.  It is roughly
    equivalent to::

        variate = np.zeros(len(colors), dtype=np.int64)
        # `remaining` is the cumulative sum of `colors` from the last
        # element to the first; e.g. if `colors` is [3, 1, 5], then
        # `remaining` is [9, 6, 5].
        remaining = np.cumsum(colors[::-1])[::-1]
        for i in range(len(colors)-1):
            if nsample < 1:
                break
            variate[i] = hypergeometric(colors[i], remaining[i+1],
                                       nsample)
            nsample -= variate[i]
        variate[-1] = nsample

    The default method is "marginals".  For some cases (e.g. when
    `colors` contains relatively small integers), the "count" method
    can be significantly faster than the "marginals" method.  If
    performance of the algorithm is important, test the two methods
    with typical inputs to decide which works best.

    .. versionadded:: 1.18.0

    Examples
    --------
    >>> colors = [16, 8, 4]
    >>> seed = 4861946401452
    >>> gen = np.random.Generator(np.random.PCG64(seed))
    >>> gen.multivariate_hypergeometric(colors, 6)
    array([5, 0, 1])
    >>> gen.multivariate_hypergeometric(colors, 6, size=3)
    array([[5, 0, 1],
           [2, 2, 2],
           [3, 3, 0]])
    >>> gen.multivariate_hypergeometric(colors, 6, size=(2, 2))
    array([[[3, 2, 1],
            [3, 2, 1]],
           [[4, 1, 1],
            [3, 2, 1]]])
    """
    cdef int64_t nsamp
    cdef size_t num_colors
    cdef int64_t total
    cdef int64_t *colors_ptr
    cdef int64_t max_index
    cdef size_t num_variates
    cdef int64_t *variates_ptr
    cdef int result

    if method not in ['count', 'marginals']:
        raise ValueError('method must be "count" or "marginals".')

    try:
        # Keep the converted value so that the range checks below and
        # the assignment to the C integer all operate on a plain Python
        # int rather than on whatever object the caller passed in.
        nsample = operator.index(nsample)
    except TypeError:
        raise ValueError('nsample must be an integer')

    if nsample < 0:
        raise ValueError("nsample must be nonnegative.")
    if nsample > INT64_MAX:
        raise ValueError("nsample must not exceed %d" % INT64_MAX)
    nsamp = nsample

    # Validation of colors, a 1-d sequence of nonnegative integers.
    invalid_colors = False
    try:
        colors = np.asarray(colors)
        if colors.ndim != 1:
            invalid_colors = True
        elif colors.size > 0 and not np.issubdtype(colors.dtype,
                                                   np.integer):
            # The dtype test is skipped for an empty input, so an empty
            # `colors` is accepted whatever dtype np.asarray gave it.
            invalid_colors = True
        elif np.any((colors < 0) | (colors > INT64_MAX)):
            invalid_colors = True
    except ValueError:
        invalid_colors = True
    if invalid_colors:
        raise ValueError('colors must be a one-dimensional sequence '
                         'of nonnegative integers not exceeding %d.' %
                         INT64_MAX)

    # The C samplers read `colors` through a raw int64_t pointer, so the
    # data must be a contiguous int64 buffer.
    colors = np.ascontiguousarray(colors, dtype=np.int64)
    num_colors = colors.size

    colors_ptr = <int64_t *> np.PyArray_DATA(colors)

    total = _safe_sum_nonneg_int64(num_colors, colors_ptr)
    if total == -1:
        raise ValueError("sum(colors) must not exceed the maximum value "
                         "of a 64 bit signed integer (%d)" % INT64_MAX)

    if method == 'marginals' and total >= 1000000000:
        raise ValueError('When method is "marginals", sum(colors) must '
                         'be less than 1000000000.')

    # The C code that implements the 'count' method will malloc an
    # array of size total*sizeof(size_t).  Here we ensure that that
    # product does not overflow.
    if SIZE_MAX > <uint64_t>INT64_MAX:
        max_index = INT64_MAX // sizeof(size_t)
    else:
        max_index = SIZE_MAX // sizeof(size_t)
    if method == 'count' and total > max_index:
        raise ValueError("When method is 'count', sum(colors) must not "
                         "exceed %d" % max_index)
    if nsamp > total:
        raise ValueError("nsample > sum(colors)")

    # Figure out the shape of the return array.
    if size is None:
        shape = (num_colors,)
    elif np.isscalar(size):
        shape = (size, num_colors)
    else:
        shape = tuple(size) + (num_colors,)
    variates = np.zeros(shape, dtype=np.int64)

    # With no colors there is nothing to draw (and the division below
    # would be by zero), so return the zero-filled array as is.
    if num_colors == 0:
        return variates

    # One variate is a vector of length num_colors.
    num_variates = variates.size // num_colors
    variates_ptr = <int64_t *> np.PyArray_DATA(variates)

    if method == 'count':
        with self.lock, nogil:
            result = random_mvhg_count(&self._bitgen, total,
                                       num_colors, colors_ptr, nsamp,
                                       num_variates, variates_ptr)
        # A return value of -1 from the C routine is reported to the
        # caller as an out-of-memory condition.
        if result == -1:
            raise MemoryError("Insufficient memory for multivariate_"
                              "hypergeometric with method='count' and "
                              "sum(colors)=%d" % total)
    else:
        with self.lock, nogil:
            random_mvhg_marginals(&self._bitgen, total,
                                  num_colors, colors_ptr, nsamp,
                                  num_variates, variates_ptr)
    return variates
3988
+
3742
3989
def dirichlet (self , object alpha , size = None ):
3743
3990
"""
3744
3991
dirichlet(alpha, size=None)
0 commit comments