10000 ENH: make random.choice size argument default to None and allow tuple · numpy/numpy@acf7421 · GitHub
[go: up one dir, main page]

Skip to content

Commit acf7421

Browse files
sebergcertik
authored andcommitted
ENH: make random.choice size argument default to None and allow tuple
The size argument to random.choice should work like it does for all other functions in random as well.
1 parent bf084bd commit acf7421

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

numpy/random/mtrand/mtrand.pyx

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ cdef class RandomState:
916916
return bytestring
917917

918918

919-
def choice(self, a, size=1, replace=True, p=None):
919+
def choice(self, a, size=None, replace=True, p=None):
920920
"""
921921
choice(a, size=1, replace=True, p=None)
922922
@@ -929,8 +929,9 @@ cdef class RandomState:
929929
a : 1-D array-like or int
930930
If an ndarray, a random sample is generated from its elements.
931931
If an int, the random sample is generated as if a was np.arange(n)
932-
size : int
933-
Positive integer, the size of the sample.
932+
size : int or tuple of ints, optional
933+
Output shape. Default is None, in which case a single value is
934+
returned.
934935
replace : boolean, optional
935936
Whether the sample is with or without replacement
936937
p : 1-D array-like, optional
@@ -1017,26 +1018,30 @@ cdef class RandomState:
10171018
if not np.allclose(p.sum(), 1):
10181019
raise ValueError("probabilities do not sum to 1")
10191020

1021+
shape = size if size is not None else tuple()
1022+
size = np.prod(shape, dtype=np.intp)
1023+
10201024
# Actual sampling
10211025
if replace:
10221026
if None != p:
10231027
cdf = p.cumsum()
10241028
cdf /= cdf[-1]
1025-
uniform_samples = np.random.random(size)
1029+
uniform_samples = np.random.random(shape)
10261030
idx = cdf.searchsorted(uniform_samples, side='right')
10271031
else:
1028-
idx = self.randint(0, pop_size, size=size)
1032+
idx = self.randint(0, pop_size, size=shape)
10291033
else:
10301034
if size > pop_size:
1031-
raise ValueError(''.join(["Cannot take a larger sample than ",
1032-
"population when 'replace=False'"]))
1035+
raise ValueError("Cannot take a larger sample than "
1036+
"population when 'replace=False'")
10331037

10341038
if None != p:
10351039
if np.sum(p > 0) < size:
10361040
raise ValueError("Fewer non-zero entries in p than size")
10371041
n_uniq = 0
10381042
p = p.copy()
1039-
found = np.zeros(size, dtype=np.int)
1043+
found = np.zeros(shape, dtype=np.int)
1044+
flat_found = found.ravel()
10401045
while n_uniq < size:
10411046
x = self.rand(size - n_uniq)
10421047
if n_uniq > 0:
@@ -1045,17 +1050,18 @@ cdef class RandomState:
10451050
cdf /= cdf[-1]
10461051
new = cdf.searchsorted(x, side='right')
10471052
new = np.unique(new)
1048-
found[n_uniq:n_uniq + new.size] = new
1053+
flat_found[n_uniq:n_uniq + new.size] = new
10491054
n_uniq += new.size
10501055
idx = found
10511056
else:
10521057
idx = self.permutation(pop_size)[:size]
1058+
idx.shape = shape
10531059

10541060
#Use samples as indices for a if a is array-like
10551061
if isinstance(a, int):
10561062
return idx
10571063
else:
1058-
return a.take(idx)
1064+
return a[idx]
10591065

10601066

10611067
def uniform(self, low=0.0, high=1.0, size=None):

0 commit comments

Comments
 (0)
0