From 3b67a2acca7d53585b7bc30a8f839f1bf3040d78 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Wed, 17 Mar 2021 12:59:51 +0000 Subject: [PATCH] ENH: Improve the exception for default low in Generator.integers Improve the exception when low is 0 in case the single input form was used. closes #14333 --- numpy/random/_bounded_integers.pyx.in | 27 ++++++++++++-------- numpy/random/tests/test_generator_mt19937.py | 15 +++++++++++ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/numpy/random/_bounded_integers.pyx.in b/numpy/random/_bounded_integers.pyx.in index 9f46685d3258..7eb6aff1e9f1 100644 --- a/numpy/random/_bounded_integers.pyx.in +++ b/numpy/random/_bounded_integers.pyx.in @@ -8,6 +8,7 @@ __all__ = [] np.import_array() + cdef extern from "numpy/random/distributions.h": # Generate random numbers in closed interval [off, off + rng]. uint64_t random_bounded_uint64(bitgen_t *bitgen_state, @@ -51,6 +52,17 @@ cdef extern from "numpy/random/distributions.h": np.npy_bool *out) nogil +cdef object format_bounds_error(bint closed, object low): + # Special case low == 0 to provide a better exception for users + # since low = 0 is the default single-argument case. + if not np.any(low): + comp = '<' if closed else '<=' + return f'high {comp} 0' + else: + comp = '>' if closed else '>=' + return f'low {comp} high' + + {{ py: type_info = (('uint32', 'uint32', 'uint64', 'NPY_UINT64', 0, 0, 0, '0X100000000ULL'), @@ -99,8 +111,7 @@ cdef object _rand_{{nptype}}_broadcast(np.ndarray low, np.ndarray high, object s if np.any(high_comp(high_arr, {{ub}})): raise ValueError('high is out of bounds for {{nptype}}') if np.any(low_high_comp(low_arr, high_arr)): - comp = '>' if closed else '>=' - raise ValueError('low {comp} high'.format(comp=comp)) + raise ValueError(format_bounds_error(closed, low_arr)) low_arr = np.PyArray_FROM_OTF(low, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST) high_arr = np.PyArray_FROM_OTF(high, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST) @@ -173,8 +184,7 @@ cdef object _rand_{{nptype}}_broadcast(object low, object high, object size, # Avoid object dtype path if already an integer high_lower_comp = np.less if closed else np.less_equal if np.any(high_lower_comp(high_arr, {{lb}})): - comp = '>' if closed else '>=' - raise ValueError('low {comp} high'.format(comp=comp)) + raise ValueError(format_bounds_error(closed, low_arr)) high_m1 = high_arr if closed else high_arr - dt.type(1) if np.any(np.greater(high_m1, {{ub}})): raise ValueError('high is out of bounds for {{nptype}}') @@ -191,13 +201,11 @@ cdef object _rand_{{nptype}}_broadcast(object low, object high, object size, if closed_upper > {{ub}}: raise ValueError('high is out of bounds for {{nptype}}') if closed_upper < {{lb}}: - comp = '>' if closed else '>=' - raise ValueError('low {comp} high'.format(comp=comp)) + raise ValueError(format_bounds_error(closed, low_arr)) highm1_data[i] = <{{nptype}}_t>closed_upper if np.any(np.greater(low_arr, highm1_arr)): - comp = '>' if closed else '>=' - raise ValueError('low {comp} high'.format(comp=comp)) + raise ValueError(format_bounds_error(closed, low_arr)) high_arr = highm1_arr low_arr = np.PyArray_FROM_OTF(low, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST) @@ -316,8 +324,7 @@ cdef object _rand_{{nptype}}(object low, object high, object size, if high > {{ub}}: raise ValueError("high is out of bounds for {{nptype}}") if low > high: # -1 already subtracted, closed interval - comp = '>' if closed else '>=' - raise ValueError('low {comp} high'.format(comp=comp)) + raise ValueError(format_bounds_error(closed, low)) rng = <{{utype}}_t>(high - low) off = <{{utype}}_t>(<{{nptype}}_t>low) diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 446b350dd8de..310545e0d8ea 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -2562,3 +2562,18 @@ def test_ragged_shuffle(): gen = Generator(MT19937(0)) assert_no_warnings(gen.shuffle, seq) assert seq == [1, [], []] + + +@pytest.mark.parametrize("high", [-2, [-2]]) +@pytest.mark.parametrize("endpoint", [True, False]) +def test_single_arg_integer_exception(high, endpoint): + # GH 14333 + gen = Generator(MT19937(0)) + msg = 'high < 0' if endpoint else 'high <= 0' + with pytest.raises(ValueError, match=msg): + gen.integers(high, endpoint=endpoint) + msg = 'low > high' if endpoint else 'low >= high' + with pytest.raises(ValueError, match=msg): + gen.integers(-1, high, endpoint=endpoint) + with pytest.raises(ValueError, match=msg): + gen.integers([-1], high, endpoint=endpoint)