@@ -3726,6 +3726,7 @@ cdef class Generator:
3726
3726
cdef np .npy_intp d , i , sz , offset
3727
3727
cdef np .ndarray parr , mnarr , on , temp_arr
3728
3728
cdef double * pix
3729
+ cdef double max_sum
3729
3730
cdef int64_t * mnix
3730
3731
cdef int64_t ni
3731
3732
cdef np .broadcast it
@@ -3736,8 +3737,16 @@ cdef class Generator:
3736
3737
pvals , np .NPY_DOUBLE , 1 , 1 , np .NPY_ARRAY_ALIGNED | np .NPY_ARRAY_C_CONTIGUOUS )
3737
3738
pix = < double * > np .PyArray_DATA (parr )
3738
3739
check_array_constraint (parr , 'pvals' , CONS_BOUNDED_0_1 )
3739
- if kahan_sum (pix , d - 1 ) > (1.0 + 1e-12 ):
3740
- raise ValueError ("sum(pvals[:-1]) > 1.0" )
3740
+ max_sum = 1.0 + 1e-12
3741
+ if kahan_sum (pix , d - 1 ) > max_sum :
3742
+ # Further checks to handle case where the pvals is an array with
3743
+ # a dtype that differs from double. Comparison is slow, but should
3744
+ # almost never be hit when pvals is valid
3745
+ if (not isinstance (pvals , np .ndarray )
3746
+ or pvals .dtype == float
3747
+ or not np .issubdtype (pvals .dtype , np .floating )
3748
+ or pvals [:- 1 ].sum () > pvals .dtype .type (max_sum )):
3749
+ raise ValueError (f"sum(pvals[:-1]) > 1.0" )
3741
3750
3742
3751
if np .PyArray_NDIM (on ) != 0 : # vector
3743
3752
check_array_constraint (on , 'n' , CONS_NON_NEGATIVE )
0 commit comments