@@ -916,7 +916,7 @@ cdef class RandomState:
916916 return bytestring
917917
918918
919- def choice (self , a , size = 1 , replace = True , p = None ):
919+ def choice (self , a , size = None , replace = True , p = None ):
920920 """
921921 choice(a, size=1, replace=True, p=None)
922922
@@ -929,8 +929,9 @@ cdef class RandomState:
929929 a : 1-D array-like or int
930930 If an ndarray, a random sample is generated from its elements.
931931 If an int, the random sample is generated as if a was np.arange(n)
932- size : int
933- Positive integer, the size of the sample.
932+ size : int or tuple of ints, optional
933+ Output shape. Default is None, in which case a single value is
934+ returned.
934935 replace : boolean, optional
935936 Whether the sample is with or without replacement
936937 p : 1-D array-like, optional
@@ -1017,26 +1018,30 @@ cdef class RandomState:
10171018 if not np.allclose(p.sum(), 1 ):
10181019 raise ValueError (" probabilities do not sum to 1" )
10191020
1021+ shape = size if size is not None else tuple ()
1022+ size = np.prod(shape, dtype = np.intp)
1023+
10201024 # Actual sampling
10211025 if replace:
10221026 if None != p:
10231027 cdf = p.cumsum()
10241028 cdf /= cdf[- 1 ]
1025- uniform_samples = np.random.random(size )
1029+ uniform_samples = np.random.random(shape )
10261030 idx = cdf.searchsorted(uniform_samples, side = ' right' )
10271031 else :
1028- idx = self .randint(0 , pop_size, size = size )
1032+ idx = self .randint(0 , pop_size, size = shape )
10291033 else :
10301034 if size > pop_size:
1031- raise ValueError (' ' .join([ " Cannot take a larger sample than " ,
1032- " population when 'replace=False'" ]) )
1035+ raise ValueError (" Cannot take a larger sample than "
1036+ " population when 'replace=False'" )
10331037
10341038 if None != p:
10351039 if np.sum(p > 0 ) < size:
10361040 raise ValueError (" Fewer non-zero entries in p than size" )
10371041 n_uniq = 0
10381042 p = p.copy()
1039- found = np.zeros(size, dtype = np.int)
1043+ found = np.zeros(shape, dtype = np.int)
1044+ flat_found = found.ravel()
10401045 while n_uniq < size:
10411046 x = self .rand(size - n_uniq)
10421047 if n_uniq > 0 :
@@ -1045,17 +1050,18 @@ cdef class RandomState:
10451050 cdf /= cdf[- 1 ]
10461051 new = cdf.searchsorted(x, side = ' right' )
10471052 new = np.unique(new)
1048- found [n_uniq:n_uniq + new.size] = new
1053+ flat_found [n_uniq:n_uniq + new.size] = new
10491054 n_uniq += new.size
10501055 idx = found
10511056 else :
10521057 idx = self .permutation(pop_size)[:size]
1058+ idx.shape = shape
10531059
10541060 # Use samples as indices for a if a is array-like
10551061 if isinstance (a, int ):
10561062 return idx
10571063 else :
1058- return a.take( idx)
1064+ return a[ idx]
10591065
10601066
10611067 def uniform (self , low = 0.0 , high = 1.0 , size = None ):
0 commit comments