@@ -916,7 +916,7 @@ cdef class RandomState:
916
916
return bytestring
917
917
918
918
919
- def choice (self , a , size = 1 , replace = True , p = None ):
919
+ def choice (self , a , size = None , replace = True , p = None ):
920
920
"""
921
921
choice(a, size=1, replace=True, p=None)
922
922
@@ -929,8 +929,9 @@ cdef class RandomState:
929
929
a : 1-D array-like or int
930
930
If an ndarray, a random sample is generated from its elements.
931
931
If an int, the random sample is generated as if a was np.arange(n)
932
- size : int
933
- Positive integer, the size of the sample.
932
+ size : int or tuple of ints, optional
933
+ Output shape. Default is None, in which case a single value is
934
+ returned.
934
935
replace : boolean, optional
935
936
Whether the sample is with or without replacement
936
937
p : 1-D array-like, optional
@@ -1017,26 +1018,30 @@ cdef class RandomState:
1017
1018
if not np.allclose(p.sum(), 1 ):
1018
1019
raise ValueError (" probabilities do not sum to 1" )
1019
1020
1021
+ shape = size if size is not None else tuple ()
1022
+ size = np.prod(shape, dtype = np.intp)
1023
+
1020
1024
# Actual sampling
1021
1025
if replace:
1022
1026
if None != p:
1023
1027
cdf = p.cumsum()
1024
1028
cdf /= cdf[- 1 ]
1025
- uniform_samples = np.random.random(size )
1029
+ uniform_samples = np.random.random(shape )
1026
1030
idx = cdf.searchsorted(uniform_samples, side = ' right' )
1027
1031
else :
1028
- idx = self .randint(0 , pop_size, size = size )
1032
+ idx = self .randint(0 , pop_size, size = shape )
1029
1033
else :
1030
1034
if size > pop_size:
1031
- raise ValueError (' ' .join([ " Cannot take a larger sample than " ,
1032
- " population when 'replace=False'" ]) )
1035
+ raise ValueError (" Cannot take a larger sample than "
1036
+ " population when 'replace=False'" )
1033
1037
1034
1038
if None != p:
1035
1039
if np.sum(p > 0 ) < size:
1036
1040
raise ValueError (" Fewer non-zero entries in p than size" )
1037
1041
n_uniq = 0
1038
1042
p = p.copy()
1039
- found = np.zeros(size, dtype = np.int)
1043
+ found = np.zeros(shape, dtype = np.int)
1044
+ flat_found = found.ravel()
1040
1045
while n_uniq < size:
1041
1046
x = self .rand(size - n_uniq)
1042
1047
if n_uniq > 0 :
@@ -1045,17 +1050,18 @@ cdef class RandomState:
1045
1050
cdf /= cdf[- 1 ]
1046
1051
new = cdf.searchsorted(x, side = ' right' )
1047
1052
new = np.unique(new)
1048
- found [n_uniq:n_uniq + new.size] = new
1053
+ flat_found [n_uniq:n_uniq + new.size] = new
1049
1054
n_uniq += new.size
1050
1055
idx = found
1051
1056
else :
1052
1057
idx = self .permutation(pop_size)[:size]
1058
+ idx.shape = shape
1053
1059
1054
1060
# Use samples as indices for a if a is array-like
1055
1061
if isinstance (a, int ):
1056
1062
return idx
1057
1063
else :
1058
- return a.take( idx)
1064
+ return a[ idx]
1059
1065
1060
1066
1061
1067
def uniform (self , low = 0.0 , high = 1.0 , size = None ):
0 commit comments