ENH: Improve error message in multinomial · numpy/numpy@359d04f · GitHub
[go: up one dir, main page]

Skip to content

Commit 359d04f

Browse files
committed
ENH: Improve error message in multinomial
Improve the error message raised when the sum of pvals is larger than 1 and the input data is an ndarray.
1 parent 9a8e3fc commit 359d04f

File tree

4 files changed

+18
-41
lines changed

4 files changed

+18
-41
lines changed

numpy/random/_common.pxd

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#cython: language_level=3
2-
from cython cimport floating
2+
33
from libc.stdint cimport uint32_t, uint64_t, int32_t, int64_t
44

55
import numpy as np
@@ -64,8 +64,6 @@ ctypedef int64_t (*random_int_2_i)(bitgen_t *state, int64_t a, int64_t b) nogil
6464

6565
cdef double kahan_sum(double *darr, np.npy_intp n)
6666

67-
cdef bint kahan_check(floating[::1] a, floating tol)
68-
6967
cdef inline double uint64_to_double(uint64_t rnd) nogil:
7068
return (rnd >> 11) * (1.0 / 9007199254740992.0)
7169

numpy/random/_common.pyx

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,20 +183,6 @@ cdef double kahan_sum(double *darr, np.npy_intp n):
183183
return sum
184184

185185

186-
cdef bint kahan_check(floating[::1] a, floating tol):
187-
cdef floating c, y, t, sum
188-
cdef np.npy_intp i, n
189-
n = a.shape[0]
190-
sum = a[0]
191-
c = 0.0
192-
for i in range(1, n):
193-
y = a[i] - c
194-
t = sum + y
195-
c = (t-sum) - y
196-
sum = t
197-
return sum > tol
198-
199-
200186
cdef object wrap_int(object val, object bits):
201187
"""Wraparound to place an integer into the interval [0, 2**bits)"""
202188
mask = ~(~int(0) << bits)

numpy/random/_generator.pyx

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ from numpy.random cimport bitgen_t
2525
from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE,
2626
CONS_NON_NEGATIVE, CONS_BOUNDED_0_1, CONS_BOUNDED_GT_0_1,
2727
CONS_GT_1, CONS_POSITIVE_NOT_NAN, CONS_POISSON,
28-
double_fill, cont, kahan_sum, kahan_check, cont_broadcast_3, float_fill, cont_f,
28+
double_fill, cont, kahan_sum, cont_broadcast_3, float_fill, cont_f,
2929
check_array_constraint, check_constraint, disc, discrete_broadcast_iii,
3030
validate_output_shape
3131
)
@@ -3725,11 +3725,7 @@ cdef class Generator:
37253725

37263726
cdef np.npy_intp d, i, sz, offset
37273727
cdef np.ndarray parr, mnarr, on, temp_arr
3728-
cdef float[::1] pvals32
3729-
cdef double[::1] pvals64
37303728
cdef double *pix
3731-
cdef bint err
3732-
cdef double max_sum
37333729
cdef int64_t *mnix
37343730
cdef int64_t ni
37353731
cdef np.broadcast it
@@ -3740,16 +3736,19 @@ cdef class Generator:
37403736
pvals, np.NPY_DOUBLE, 1, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
37413737
pix = <double*>np.PyArray_DATA(parr)
37423738
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
3743-
if isinstance(pvals, np.ndarray) and pvals.dtype == np.float32:
3744-
# Special case 32 bit floats
3745-
pvals32 = np.PyArray_GETCONTIGUOUS(pvals)
3746-
err = kahan_check(pvals32, <float>1.0000001)
3747-
else:
3748-
# All other original types are tested using 64 bit floats
3749-
pvals64 = parr
3750-
err = kahan_check(pvals64, 1.0 + 1e-12)
3751-
if err:
3752-
raise ValueError(f"sum(pvals[:-1]) > 1.0")
3739+
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
3740+
msg = "sum(pvals[:-1]) > 1.0"
3741+
# When floating, but not float dtype, and close, improve the error
3742+
# 1.0001 works for float16 and float32
3743+
if (isinstance(pvals, np.ndarray) and
3744+
pvals.dtype != float and
3745+
np.issubdtype(pvals.dtype, np.floating) and
3746+
pvals.sum() < 1.0001):
3747+
msg += (". pvals has been cast to double before checking "
3748+
"the sum. Changes in precision when casting may "
3749+
"produce violations even if pvals.sum() <= 1 when "
3750+
"evaluated in its original dtype.")
3751+
raise ValueError(msg)
37533752

37543753
if np.PyArray_NDIM(on) != 0: # vector
37553754
check_array_constraint(on, 'n', CONS_NON_NEGATIVE)

numpy/random/tests/test_generator_mt19937.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,9 @@ def test_multinomial_pvals_float32(self):
147147
1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
148148
pvals = x / x.sum()
149149
random = Generator(MT19937(1432985819))
150-
result = random.multinomial(1, pvals)
151-
152-
random = Generator(MT19937(1432985819))
153-
pvals = pvals.astype(float)
154-
assert_raises(ValueError, random.multinomial, 1, pvals)
155-
pvals = pvals / pvals.sum()
156-
expected = random.multinomial(1, pvals)
157-
assert_array_equal(result, expected)
158-
150+
match = r"[\w\s]*pvals has been cast to double"
151+
with pytest.raises(ValueError, match=match):
152+
random.multinomial(1, pvals)
159153

160154
class TestMultivariateHypergeometric:
161155

0 commit comments

Comments
 (0)
0