diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index f25f16a8a155..cd950b3fa8d5 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -3756,7 +3756,20 @@ cdef class Generator: pix = np.PyArray_DATA(parr) check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1) if kahan_sum(pix, d-1) > (1.0 + 1e-12): - raise ValueError("sum(pvals[:-1]) > 1.0") + # When floating, but not float dtype, and close, improve the error + # 1.0001 works for float16 and float32 + if (isinstance(pvals, np.ndarray) + and np.issubdtype(pvals.dtype, np.floating) + and pvals.dtype != float + and pvals.sum() < 1.0001): + msg = ("sum(pvals[:-1].astype(np.float64)) > 1.0. The pvals " + "array is cast to 64-bit floating point prior to " + "checking the sum. Precision changes when casting may " + "cause problems even if the sum of the original pvals " + "is valid.") + else: + msg = "sum(pvals[:-1]) > 1.0" + raise ValueError(msg) if np.PyArray_NDIM(on) != 0: # vector check_array_constraint(on, 'n', CONS_NON_NEGATIVE) diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 6f44e271f245..4e12f8e59264 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -4232,7 +4232,20 @@ cdef class RandomState: pix = np.PyArray_DATA(parr) check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1) if kahan_sum(pix, d-1) > (1.0 + 1e-12): - raise ValueError("sum(pvals[:-1]) > 1.0") + # When floating, but not float dtype, and close, improve the error + # 1.0001 works for float16 and float32 + if (isinstance(pvals, np.ndarray) + and np.issubdtype(pvals.dtype, np.floating) + and pvals.dtype != float + and pvals.sum() < 1.0001): + msg = ("sum(pvals[:-1].astype(np.float64)) > 1.0. The pvals " + "array is cast to 64-bit floating point prior to " + "checking the sum. Precision changes when casting may " + "cause problems even if the sum of the original pvals " + "is valid.") + else: + msg = "sum(pvals[:-1]) > 1.0" + raise ValueError(msg) if size is None: shape = (d,) diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index f6bd985b4510..446b350dd8de 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -142,6 +142,14 @@ def test_multidimensional_pvals(self): assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]]) assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]])) + def test_multinomial_pvals_float32(self): + x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09, + 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32) + pvals = x / x.sum() + random = Generator(MT19937(1432985819)) + match = r"[\w\s]*pvals array is cast to 64-bit floating" + with pytest.raises(ValueError, match=match): + random.multinomial(1, pvals) class TestMultivariateHypergeometric: diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py index b16275b700e6..861813a95d1f 100644 --- a/numpy/random/tests/test_randomstate.py +++ b/numpy/random/tests/test_randomstate.py @@ -167,6 +167,14 @@ def test_p_non_contiguous(self): contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals)) assert_array_equal(non_contig, contig) + def test_multinomial_pvals_float32(self): + x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09, + 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32) + pvals = x / x.sum() + match = r"[\w\s]*pvals array is cast to 64-bit floating" + with pytest.raises(ValueError, match=match): + random.multinomial(1, pvals) + class TestSetState: def setup(self):