8000 MAINT Parameters validation for utils.gen_batches (#25864) · jeremiedbb/scikit-learn@f4fbb28 · GitHub
[go: up one dir, main page]

Skip to content

Commit f4fbb28

Browse files
MAINT Parameters validation for utils.gen_batches (scikit-learn#25864)
1 parent 9d55835 commit f4fbb28

File tree

3 files changed

+9
-20
lines changed

3 files changed

+9
-20
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _check_function_param_validation(
189189
"sklearn.model_selection.train_test_split",
190190
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
191191
"sklearn.svm.l1_min_c",
192+
"sklearn.utils.gen_batches",
192193
]
193194

194195

sklearn/utils/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from .. import get_config
4141
from ._bunch import Bunch
42+
from ._param_validation import validate_params, Interval
4243

4344

4445
# Do not deprecate parallel_backend and register_parallel_backend as they are
@@ -725,6 +726,13 @@ def _chunk_generator(gen, chunksize):
725726
return
726727

727728

729+
@validate_params(
730+
{
731+
"n": [Interval(numbers.Integral, 1, None, closed="left")],
732+
"batch_size": [Interval(numbers.Integral, 1, None, closed="left")],
733+
"min_batch_size": [Interval(numbers.Integral, 0, None, closed="left")],
734+
}
735+
)
728736
def gen_batches(n, batch_size, *, min_batch_size=0):
729737
"""Generator to create slices containing `batch_size` elements from 0 to `n`.
730738
@@ -762,12 +770,6 @@ def gen_batches(n, batch_size, *, min_batch_size=0):
762770
>>> list(gen_batches(7, 3, min_batch_size=2))
763771
[slice(0, 3, None), slice(3, 7, None)]
764772
"""
765-
if not isinstance(batch_size, numbers.Integral):
766-
raise TypeError(
767-
"gen_batches got batch_size=%s, must be an integer" % batch_size
768-
)
769-
if batch_size <= 0:
770-
raise ValueError("gen_batches got batch_size=%s, must be positive" % batch_size)
771773
start = 0
772774
for _ in range(int(n // batch_size)):
773775
end = start + batch_size

sklearn/utils/tests/test_utils.py< EBE2 /h3>

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from sklearn.utils import check_random_state
1818
from sklearn.utils import _determine_key_type
1919
from sklearn.utils import deprecated
20-
from sklearn.utils import gen_batches
2120
from sklearn.utils import _get_column_indices
2221
from sklearn.utils import resample
2322
from sklearn.utils import safe_mask
@@ -56,19 +55,6 @@ def test_make_rng():
5655
check_random_state("some invalid seed")
5756

5857

59-
def test_gen_batches():
60-
# Make sure gen_batches errors on invalid batch_size
61-
62-
assert_array_equal(list(gen_batches(4, 2)), [slice(0, 2, None), slice(2, 4, None)])
63-
msg_zero = "gen_batches got batch_size=0, must be positive"
64-
with pytest.raises(ValueError, match=msg_zero):
65-
next(gen_batches(4, 0))
66-
67-
msg_float = "gen_batches got batch_size=0.5, must be an integer"
68-
with pytest.raises(TypeError, match=msg_float):
69-
next(gen_batches(4, 0.5))
70-
71-
7258
def test_deprecated():
7359
# Test whether the deprecated decorator issues appropriate warnings
7460
# Copied almost verbatim from https://docs.python.org/library/warnings.html

0 commit comments

Comments
 (0)
0