8000 ENH/BUG: Allow multinomial to check pvals with other float types · bashtage/numpy@c1e3e11 · GitHub
[go: up one dir, main page]

Skip to content

Commit c1e3e11

Browse files
committed
ENH/BUG: Allow multinomial to check pvals with other float types
Add additional check when original input is an array that does not have dtype double closes numpy#8317
1 parent 0eb9f54 commit c1e3e11

File tree

2 files changed

+25
-2
lines changed
  • tests
  • 2 files changed

    +25
    -2
    lines changed

    numpy/random/_generator.pyx

    Lines changed: 11 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -3726,6 +3726,7 @@ cdef class Generator:
    37263726
    cdef np.npy_intp d, i, sz, offset
    37273727
    cdef np.ndarray parr, mnarr, on, temp_arr
    37283728
    cdef double *pix
    3729+
    cdef double max_sum
    37293730
    cdef int64_t *mnix
    37303731
    cdef int64_t ni
    37313732
    cdef np.broadcast it
    @@ -3736,8 +3737,16 @@ cdef class Generator:
    37363737
    pvals, np.NPY_DOUBLE, 1, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
    37373738
    pix = <double*>np.PyArray_DATA(parr)
    37383739
    check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
    3739-
    if kahan_sum(pix, d-1) > (1.0 + 1e-12):
    3740-
    raise ValueError("sum(pvals[:-1]) > 1.0")
    3740+
    max_sum = 1.0 + 1e-12
    3741+
    if kahan_sum(pix, d-1) > max_sum:
    3742+
    # Further checks to handle case where the pvals is an array with
    3743+
    # a dtype that differs from double. Comparison is slow, but should
    3744+
    # almost never be hit when pvals is valid
    3745+
    if (not isinstance(pvals, np.ndarray)
    3746+
    or pvals.dtype == float
    3747+
    or not np.issubdtype(pvals.dtype, np.floating)
    3748+
    or pvals[:-1].sum() > pvals.dtype.type(max_sum)):
    3749+
    raise ValueError(f"sum(pvals[:-1]) > 1.0")
    37413750

    37423751
    if np.PyArray_NDIM(on) != 0: # vector
    37433752
    check_array_constraint(on, 'n', CONS_NON_NEGATIVE)

    numpy/random/tests/test_generator_mt19937.py

    Lines changed: 14 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -142,6 +142,20 @@ def test_multidimensional_pvals(self):
    142142
    assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]])
    143143
    assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]]))
    144144

    145+
    def test_multinomial_pvals_float32(self):
    146+
    x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09,
    147+
    1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
    148+
    pvals = x / x.sum()
    149+
    random = Generator(MT19937(1432985819))
    150+
    result = random.multinomial(1, pvals)
    151+
    152+
    random = Generator(MT19937(1432985819))
    153+
    pvals = pvals.astype(float)
    154+
    assert_raises(ValueError, random.multinomial, 1, pvals)
    155+
    pvals = pvals / pvals.sum()
    156+
    expected = random.multinomial(1, pvals)
    157+
    assert_array_equal(result, expected)
    158+
    145159

    146160
    class TestMultivariateHypergeometric:
    147161

    0 commit comments

    Comments
     (0)
    0