8000 ENH: Allow longdouble probabilities · bashtage/numpy@f447348 · GitHub
[go: up one dir, main page]

Skip to content

Commit f447348

Browse files
bashtageKevin Sheppard
authored andcommitted
ENH: Allow longdouble probabilities
Force cast longdouble to double to allow longdoubles to be used as probabilities closes numpy#6132
1 parent da6952f commit f447348

File tree

6 files changed

+65
-16
lines changed

6 files changed

+65
-16
lines changed

numpy/random/_common.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ cdef double kahan_sum(double *darr, np.npy_intp n)
6767
cdef inline double uint64_to_double(uint64_t rnd) nogil:
6868
return (rnd >> 11) * (1.0 / 9007199254740992.0)
6969

70+
cdef np.ndarray convert_floating(object prob)
71+
7072
cdef object double_fill(void *func, bitgen_t *state, object size, object lock, object out)
7173

7274
cdef object float_fill(void *func, bitgen_t *state, object size, object lock, object out)

numpy/random/_common.pyx

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,31 @@ cdef check_output(object out, object dtype, object size, bint require_c_array):
272272
raise ValueError('size must match out.shape when used together')
273273

274274

275+
cdef np.ndarray convert_floating(object prob):
276+
"""
277+
Convert array-like floating to float64
278+
279+
Parameters
280+
----------
281+
prob : array_like
282+
Probabilities to convert
283+
284+
Returns
285+
-------
286+
prob_arr : ndarray
287+
An double array that is aligned and c-contiguous. If prob is an
288+
ndarray with a longdouble dtype, force casts to double. Otherwise
289+
uses safe casting.
290+
"""
291+
cdef int requirements = np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS
292+
293+
if np.PyArray_Check(prob) and np.issubdtype(prob.dtype, np.longdouble):
294+
# Force cast to allow longdouble, others are safe
295+
requirements |= np.NPY_ARRAY_FORCECAST
296+
297+
return <np.ndarray>np.PyArray_FROM_OTF(prob, np.NPY_DOUBLE, requirements)
298+
299+
275300
cdef object double_fill(void *func, bitgen_t *state, object size, object lock, object out):
276301
cdef random_double_fill random_func = (<random_double_fill>func)
277302
cdef double out_val

numpy/random/_generator.pyx

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE,
2727
CONS_GT_1, CONS_POSITIVE_NOT_NAN, CONS_POISSON,
2828
double_fill, cont, kahan_sum, cont_broadcast_3, float_fill, cont_f,
2929
check_array_constraint, check_constraint, disc, discrete_broadcast_iii,
30-
validate_output_shape
30+
validate_output_shape, convert_floating
3131
)
3232

3333
cdef extern from "numpy/arrayobject.h":
@@ -700,6 +700,7 @@ cdef class Generator:
700700
cdef np.npy_intp j
701701
cdef uint64_t set_size, mask
702702
cdef uint64_t[::1] hash_set
703+
cdef int requirements
703704
# Format and Verify input
704705
a_original = a
705706
a = np.array(a, copy=False)
@@ -721,14 +722,12 @@ cdef class Generator:
721722

722723
if p is not None:
723724
d = len(p)
724-
725725
atol = np.sqrt(np.finfo(np.float64).eps)
726-
if isinstance(p, np.ndarray):
727-
if np.issubdtype(p.dtype, np.floating):
728-
atol = max(atol, np.sqrt(np.finfo(p.dtype).eps))
726+
if isinstance(p, np.ndarray) and np.issubdtype(p.dtype, np.floating):
727+
# Force cast to allow float128, others are safe
728+
atol = max(atol, np.sqrt(np.finfo(p.dtype).eps))
729729

730-
p = <np.ndarray>np.PyArray_FROM_OTF(
731-
p, np.NPY_DOUBLE, np.NPY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
730+
p = <np.ndarray>convert_floating(p)
732731
pix = <double*>np.PyArray_DATA(p)
733732

734733
if p.ndim != 1:
@@ -2898,7 +2897,7 @@ cdef class Generator:
28982897
cdef np.int64_t *randoms_data
28992898
cdef np.broadcast it
29002899

2901-
p_arr = <np.ndarray>np.PyArray_FROM_OTF(p, np.NPY_DOUBLE, np.NPY_ALIGNED)
2900+
p_arr = <np.ndarray>convert_floating(p)
29022901
is_scalar = is_scalar and np.PyArray_NDIM(p_arr) == 0
29032902
n_arr = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_INT64, np.NPY_ALIGNED)
29042903
is_scalar = is_scalar and np.PyArray_NDIM(n_arr) == 0
@@ -3754,8 +3753,9 @@ cdef class Generator:
37543753

37553754
d = len(pvals)
37563755
on = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_INT64, np.NPY_ALIGNED)
3757-
parr = <np.ndarray>np.PyArray_FROMANY(
3758-
pvals, np.NPY_DOUBLE, 1, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
3756+
parr = <np.ndarray>convert_floating(pvals)
3757+
if np.PyArray_NDIM(parr) != 1:
3758+
raise ValueError("pvals must be a 1d array")
37593759
pix = <double*>np.PyArray_DATA(parr)
37603760
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
37613761
if kahan_sum(pix, d-1) > (1.0 + 1e-12):

numpy/random/mtrand.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE,
2323
CONS_GT_1, LEGACY_CONS_POISSON,
2424
double_fill, cont, kahan_sum, cont_broadcast_3,
2525
check_array_constraint, check_constraint, disc, discrete_broadcast_iii,
26-
validate_output_shape
26+
validate_output_shape, convert_floating
2727
)
2828

2929
cdef extern from "numpy/random/distributions.h":
@@ -922,8 +922,7 @@ cdef class RandomState:
922922
if np.issubdtype(p.dtype, np.floating):
923923
atol = max(atol, np.sqrt(np.finfo(p.dtype).eps))
924924

925-
p = <np.ndarray>np.PyArray_FROM_OTF(
926-
p, np.NPY_DOUBLE, np.NPY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
925+
p = <np.ndarray>convert_floating(p)
927926
pix = <double*>np.PyArray_DATA(p)
928927

929928
if p.ndim != 1:
@@ -3377,7 +3376,7 @@ cdef class RandomState:
33773376
cdef long *randoms_data
33783377
cdef np.broadcast it
33793378

3380-
p_arr = <np.ndarray>np.PyArray_FROM_OTF(p, np.NPY_DOUBLE, np.NPY_ALIGNED)
3379+
p_arr = <np.ndarray>convert_floating(p)
33813380
is_scalar = is_scalar and np.PyArray_NDIM(p_arr) == 0
33823381
n_arr = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_LONG, np.NPY_ALIGNED)
33833382
is_scalar = is_scalar and np.PyArray_NDIM(n_arr) == 0
@@ -4229,8 +4228,9 @@ cdef class RandomState:
42294228
cdef long ni
42304229

42314230
d = len(pvals)
4232-
parr = <np.ndarray>np.PyArray_FROMANY(
4233-
pvals, np.NPY_DOUBLE, 1, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
4231+
parr = <np.ndarray>convert_floating(pvals)
4232+
if np.PyArray_NDIM(parr) != 1:
4233+
raise ValueError("pvals must be a 1d array")
42344234
pix = <double*>np.PyArray_DATA(parr)
42354235
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
42364236
if kahan_sum(pix, d-1) > (1.0 + 1e-12):

numpy/random/tests/test_generator_mt19937.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,17 @@ def test_choice_large_sample(self):
921921
res = hashlib.sha256(actual.view(np.int8)).hexdigest()
922922
assert_(choice_hash == res)
923923

924+
def test_choice_longdouble(self):
925+
random = Generator(MT19937(self.seed))
926+
p = np.array([.4, .4, .2])
927+
actual = random.choice([0, 1, 2], size=2, p=p)
928+
929+
random = Generator(MT19937(self.seed))
930+
p = np.array([.4, .4, .2], dtype=np.longdouble)
931+
actual_128 = random.choice([0, 1, 2], size=2, p=p)
932+
933+
assert_equal(actual, actual_128)
934+
924935
def test_bytes(self):
925936
random = Generator(MT19937(self.seed))
926937
actual = random.bytes(10)

numpy/random/tests/test_randomstate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,17 @@ def test_choice_p_non_contiguous(self):
660660
contig = random.choice(5, 3, p=np.ascontiguousarray(p[::2]))
661661
assert_array_equal(non_contig, contig)
662662

663+
def test_choice_longdouble(self):
664+
random.seed(self.seed)
665+
p = np.array([.4, .4, .2])
666+
actual = random.choice([0, 1, 2], size=2, p=p)
667+
668+
random.seed(self.seed)
669+
p = np.array([.4, .4, .2], dtype=np.longdouble)
670+
actual_128 = random.choice([0, 1, 2], size=2, p=p)
671+
672+
assert_equal(actual, actual_128)
673+
663674
def test_bytes(self):
664675
random.seed(self.seed)
665676
actual = random.bytes(10)

0 commit comments

Comments
 (0)
0