diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index ea1b64cda62f9..d2a4e74fc5e4f 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -172,6 +172,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "sklearn.utils.gen_batches", ] diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 923c08d44c6f4..0b3119c0bfa03 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -39,6 +39,7 @@ ) from .. import get_config from ._bunch import Bunch +from ._param_validation import validate_params, Interval # Do not deprecate parallel_backend and register_parallel_backend as they are @@ -725,6 +726,13 @@ def _chunk_generator(gen, chunksize): return +@validate_params( + { + "n": [Interval(numbers.Integral, 1, None, closed="left")], + "batch_size": [Interval(numbers.Integral, 1, None, closed="left")], + "min_batch_size": [Interval(numbers.Integral, 0, None, closed="left")], + } +) def gen_batches(n, batch_size, *, min_batch_size=0): """Generator to create slices containing `batch_size` elements from 0 to `n`. @@ -762,12 +770,6 @@ def gen_batches(n, batch_size, *, min_batch_size=0): >>> list(gen_batches(7, 3, min_batch_size=2)) [slice(0, 3, None), slice(3, 7, None)] """ - if not isinstance(batch_size, numbers.Integral): - raise TypeError( - "gen_batches got batch_size=%s, must be an integer" % batch_size - ) - if batch_size <= 0: - raise ValueError("gen_batches got batch_size=%s, must be positive" % batch_size) start = 0 for _ in range(int(n // batch_size)): end = start + batch_size diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 848985f267c92..a000394bbee28 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -17,7 +17,6 @@ from sklearn.utils import check_random_state from sklearn.utils import _determine_key_type from sklearn.utils import deprecated -from sklearn.utils import gen_batches from sklearn.utils import _get_column_indices from sklearn.utils import resample from sklearn.utils import safe_mask @@ -56,19 +55,6 @@ def test_make_rng(): check_random_state("some invalid seed") -def test_gen_batches(): - # Make sure gen_batches errors on invalid batch_size - - assert_array_equal(list(gen_batches(4, 2)), [slice(0, 2, None), slice(2, 4, None)]) - msg_zero = "gen_batches got batch_size=0, must be positive" - with pytest.raises(ValueError, match=msg_zero): - next(gen_batches(4, 0)) - - msg_float = "gen_batches got batch_size=0.5, must be an integer" - with pytest.raises(TypeError, match=msg_float): - next(gen_batches(4, 0.5)) - - def test_deprecated(): # Test whether the deprecated decorator issues appropriate warnings # Copied almost verbatim from https://docs.python.org/library/warnings.html