8000 Merge pull request #11383 from mattip/recreate-8717 · numpy/numpy@464f79e · GitHub
[go: up one dir, main page]

Skip to content

Commit 464f79e

Browse files
authored
Merge pull request #11383 from mattip/recreate-8717
ENH: Allow size=0 in numpy.random.choice
2 parents fca7a15 + 9013b0f commit 464f79e

File tree

4 files changed

+30
-15
lines changed

4 files changed

+30
-15
lines changed

doc/release/1.16.0-notes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ New Features
3434
Improvements
3535
============
3636

37+
``randint`` and ``choice`` now work on empty distributions
38+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
39+
Even when no elements needed to be drawn, ``np.random.randint`` and
40+
``np.random.choice`` raised an error when the arguments described an empty
41+
distribution. This has been fixed so that e.g.
42+
``np.random.choice([], 0) == np.array([], dtype=float64)``.
3743

3844
Changes
3945
=======

numpy/random/mtrand/mtrand.pyx

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2323

2424
include "Python.pxi"
25-
include "randint_helpers.pxi"
2625
include "numpy.pxd"
26+
include "randint_helpers.pxi"
2727
include "cpython/pycapsule.pxd"
2828

2929
from libc cimport string
@@ -988,9 +988,9 @@ cdef class RandomState:
988988
raise ValueError("low is out of bounds for %s" % dtype)
989989
if ihigh > highbnd:
990990
raise ValueError("high is out of bounds for %s" % dtype)
991-
if ilow >= ihigh:
992-
raise ValueError("low >= high")
993-
991+
if ilow >= ihigh and np.prod(size) != 0:
992+
raise ValueError("Range cannot be empty (low >= high) unless no samples are taken")
993+
994994
with self.lock:
995995
ret = randfunc(ilow, ihigh - 1, size, self.state_address)
996996

@@ -1114,15 +1114,15 @@ cdef class RandomState:
11141114
# __index__ must return an integer by python rules.
11151115
pop_size = operator.index(a.item())
11161116
except TypeError:
1117-
raise ValueError("a must be 1-dimensional or an integer")
1118-
if pop_size <= 0:
1119-
raise ValueError("a must be greater than 0")
1117+
raise ValueError("'a' must be 1-dimensional or an integer")
1118+
if pop_size <= 0 and np.prod(size) != 0:
1119+
raise ValueError("'a' must be greater than 0 unless no samples are taken")
11201120
elif a.ndim != 1:
1121-
raise ValueError("a must be 1-dimensional")
1121+
raise ValueError("'a' must be 1-dimensional")
11221122
else:
11231123
pop_size = a.shape[0]
1124-
if pop_size is 0:
1125-
raise ValueError("a must be non-empty")
1124+
if pop_size is 0 and np.prod(size) != 0:
1125+
raise ValueError("'a' cannot be empty unless no samples are taken")
11261126

11271127
if p is not None:
11281128
d = len(p)
@@ -1136,9 +1136,9 @@ cdef class RandomState:
11361136
pix = <double*>PyArray_DATA(p)
11371137

11381138
if p.ndim != 1:
1139-
raise ValueError("p must be 1-dimensional")
1139+
raise ValueError("'p' must be 1-dimensional")
11401140
if p.size != pop_size:
1141-
raise ValueError("a and p must have same size")
1141+
raise ValueError("'a' and 'p' must have same size")
11421142
if np.logical_or.reduce(p < 0):
11431143
raise ValueError("probabilities are not non-negative")
11441144
if abs(kahan_sum(pix, d) - 1.) > atol:

numpy/random/mtrand/randint_helpers.pxi.in

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_dispatch(dtypes):
2323

2424
{{for npy_dt, npy_udt, np_dt in get_dispatch(dtypes)}}
2525

26-
def _rand_{{npy_dt}}(low, high, size, rngstate):
26+
def _rand_{{npy_dt}}(npy_{{npy_dt}} low, npy_{{npy_dt}} high, size, rngstate):
2727
"""
2828
_rand_{{npy_dt}}(low, high, size, rngstate)
2929

@@ -60,8 +60,8 @@ def _rand_{{npy_dt}}(low, high, size, rngstate):
6060
cdef npy_intp cnt
6161
cdef rk_state *state = <rk_state *>PyCapsule_GetPointer(rngstate, NULL)
6262

63-
rng = <npy_{{npy_udt}}>(high - low)
64-
off = <npy_{{npy_udt}}>(<npy_{{npy_dt}}>low)
63+
off = <npy_{{npy_udt}}>(low)
64+
rng = <npy_{{npy_udt}}>(high) - <npy_{{npy_udt} 57A7 }>(low)
6565

6666
if size is None:
6767
rk_random_{{npy_udt}}(off, rng, 1, &buf, state)

numpy/random/tests/test_random.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,15 @@ def test_choice_return_shape(self):
440440
assert_equal(np.random.choice(6, s, replace=False, p=p).shape, s)
441441
assert_equal(np.random.choice(np.arange(6), s, replace=True).shape, s)
442442

443+
# Check zero-size
444+
assert_equal(np.random.randint(0, 0, size=(3, 0, 4)).shape, (3, 0, 4))
445+
assert_equal(np.random.randint(0, -10, size=0).shape, (0,))
446+
assert_equal(np.random.randint(10, 10, size=0).shape, (0,))
447+
assert_equal(np.random.choice(0, size=0).shape, (0,))
448+
assert_equal(np.random.choice([], size=(0,)).shape, (0,))
449+
assert_equal(np.random.choice(['a', 'b'], size=(3, 0, 4)).shape, (3, 0, 4))
450+
assert_raises(ValueError, np.random.choice, [], 10)
451+
443452
def test_bytes(self):
444453
np.random.seed(self.seed)
445454
actual = np.random.bytes(10)

0 commit comments

Comments
 (0)
0