From 4c8642f0e4d743ae80fc143ef826cc76b88f5188 Mon Sep 17 00:00:00 2001 From: Rick Nitsche Date: Wed, 6 Jun 2018 16:42:45 -0700 Subject: [PATCH] BUG: check for probabilities containing nan-value in random.choice --- numpy/random/mtrand/mtrand.pyx | 4 +++- numpy/random/tests/test_random.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index b45b3146f145..6aba88f18fc4 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -1142,7 +1142,9 @@ cdef class RandomState: raise ValueError("a and p must have same size") if np.logical_or.reduce(p < 0): raise ValueError("probabilities are not non-negative") - if abs(kahan_sum(pix, d) - 1.) > atol: + + # negated to handle NaNs in p + if not abs(kahan_sum(pix, d) - 1.) <= atol: raise ValueError("probabilities do not sum to 1") shape = size diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 61c6e912dae6..53d0d67ae3f1 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -403,6 +403,8 @@ def test_choice_exceptions(self): assert_raises(ValueError, sample, [1, 2, 3], 4, replace=False) assert_raises(ValueError, sample, [1, 2, 3], 2, replace=False, p=[1, 0, 0]) + with np.errstate(invalid='ignore'): + assert_raises(ValueError, sample, [42, 2, 3], p=[None, None, None]) def test_choice_return_shape(self): p = [0.1, 0.9]