8000 Port error to RandomState · numpy/numpy@b1015ad · GitHub
[go: up one dir, main page]

Skip to content

Commit b1015ad

Browse files
author
Kevin Sheppard
committed
Port error to RandomState
1 parent e900be2 commit b1015ad

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

numpy/random/mtrand.pyx

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4232,7 +4232,20 @@ cdef class RandomState:
42324232
pix = <double*>np.PyArray_DATA(parr)
42334233
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
42344234
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
4235-
raise ValueError("sum(pvals[:-1]) > 1.0")
4235+
# When floating, but not float dtype, and close, improve the error
4236+
# 1.0001 works for float16 and float32
4237+
if (isinstance(pvals, np.ndarray)
4238+
and np.issubdtype(pvals.dtype, np.floating)
4239+
and pvals.dtype != float
4240+
and pvals.sum() < 1.0001):
4241+
msg = ("sum(pvals[:-1].astype(np.float64)) > 1.0. The pvals "
4242+
"array is cast to 64-bit floating point prior to "
4243+
"checking the sum. Precision changes when casting may "
4244+
"cause problems even if the sum of the original pvals "
4245+
"is valid.")
4246+
else:
4247+
msg = "sum(pvals[:-1]) > 1.0"
4248+
raise ValueError(msg)
42364249

42374250
if size is None:
42384251
shape = (d,)

numpy/random/tests/test_generator_mt19937.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_multinomial_pvals_float32(self):
147147
1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
148148
pvals = x / x.sum()
149149
random = Generator(MT19937(1432985819))
150-
match = r"[\w\s]*pvals are cast to 64-bit floating"
150+
match = r"[\w\s]*pvals array is cast to 64-bit floating"
151151
with pytest.raises(ValueError, match=match):
152152
random.multinomial(1, pvals)
153153

numpy/random/tests/test_randomstate.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@ def test_p_non_contiguous(self):
167167
contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals))
168168
assert_array_equal(non_contig, contig)
169169

170+
def test_multinomial_pvals_float32(self):
171+
x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09,
172+
1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
173+
pvals = x / x.sum()
174+
match = r"[\w\s]*pvals array is cast to 64-bit floating"
175+
with pytest.raises(ValueError, match=match):
176+
random.multinomial(1, pvals)
177+
170178

171179
class TestSetState:
172180
def setup(self):

0 commit comments

Comments
 (0)
0