8000 FIX GridSearchCV and HalvingGridSearchCV remove validation from __ini… · scikit-learn/scikit-learn@7774697 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7774697

Browse files
MrinalTyagiogriselthomasjpfan
authored
FIX GridSearchCV and HalvingGridSearchCV remove validation from __init__ and set_params (#21880)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent f853e78 commit 7774697

File tree

5 files changed

+51
-57
lines changed

5 files changed

+51
-57
lines changed

doc/whats_new/v1.1.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ Changelog
275275
splits failed. Similarly raise an error during grid-search when the fits for
276276
all the models and all the splits failed. :pr:`21026` by :user:`Loïc Estève <lesteve>`.
277277

278+
- |Fix| :class:`model_selection.GridSearchCV`, :class:`model_selection.HalvingGridSearchCV`
279+
now validate input parameters in `fit` instead of `__init__`.
280+
:pr:`21880` by :user:`Mrinal Tyagi <MrinalTyagi>`.
281+
278282
:mod:`sklearn.pipeline`
279283
.......................
280284

sklearn/model_selection/_search.py

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class ParameterGrid:
9494
def __init__(self, param_grid):
9595
if not isinstance(param_grid, (Mapping, Iterable)):
9696
raise TypeError(
97-
"Parameter grid is not a dict or a list ({!r})".format(param_grid)
97+
f"Parameter grid should be a dict or a list, got: {param_grid!r} of"
98+
f" type {type(param_grid).__name__}"
9899
)
99100

100101
if isinstance(param_grid, Mapping):
@@ -105,12 +106,26 @@ def __init__(self, param_grid):
105106
# check if all entries are dictionaries of lists
106107
for grid in param_grid:
107108
if not isinstance(grid, dict):
108-
raise TypeError("Parameter grid is not a dict ({!r})".format(grid))
109-
for key in grid:
110-
if not isinstance(grid[key], Iterable):
109+
raise TypeError(f"Parameter grid is not a dict ({grid!r})")
110+
for key, value in grid.items():
111+
if isinstance(value, np.ndarray) and value.ndim > 1:
112+
raise ValueError(
113+
f"Parameter array for {key!r} should be one-dimensional, got:"
114+
f" {value!r} with shape {value.shape}"
115+
)
116+
if isinstance(value, str) or not E7F5 isinstance(
117+
value, (np.ndarray, Sequence)
118+
):
111119
raise TypeError(
112-
"Parameter grid value is not iterable "
113-
"(key={!r}, value={!r})".format(key, grid[key])
120+
f"Parameter grid for parameter {key!r} needs to be a list or a"
121+
f" numpy array, but got {value!r} (of type "
122+
f"{type(value).__name__}) instead. Single values "
123+
"need to be wrapped in a list with one element."
124+
)
125+
if len(value) == 0:
126+
raise ValueError(
127+
f"Parameter grid for parameter {key!r} need "
128+
f"to be a non-empty sequence, got: {value!r}"
114129
)
115130

116131
self.param_grid = param_grid
@@ -244,9 +259,9 @@ class ParameterSampler:
244259
def __init__(self, param_distributions, n_iter, *, random_state=None):
245260
if not isinstance(param_distributions, (Mapping, Iterable)):
246261
raise TypeError(
247-
"Parameter distribution is not a dict or a list ({!r})".format(
248-
param_distributions
249-
)
262+
"Parameter distribution is not a dict or a list,"
263+
f" got: {param_distributions!r} of type "
264+
f"{type(param_distributions).__name__}"
250265
)
251266

252267
if isinstance(param_distributions, Mapping):
@@ -264,8 +279,8 @@ def __init__(self, param_distributions, n_iter, *, random_state=None):
264279
dist[key], "rvs"
265280
):
266281
raise TypeError(
267-
"Parameter value is not iterable "
268-
"or distribution (key={!r}, value={!r})".format(key, dist[key])
282+
f"Parameter grid for parameter {key!r} is not iterable "
283+
f"or a distribution (value={dist[key]})"
269284
)
270285
self.n_iter = n_iter
271286
self.random_state = random_state
@@ -321,30 +336,6 @@ def __len__(self):
321336
return self.n_iter
322337

323338

324-
def _check_param_grid(param_grid):
325-
if hasattr(param_grid, "items"):
326-
param_grid = [param_grid]
327-
328-
for p in param_grid:
329-
for name, v in p.items():
330-
if isinstance(v, np.ndarray) and v.ndim > 1:
331-
raise ValueError("Parameter array should be one-dimensional.")
332-
333-
if isinstance(v, str) or not isinstance(v, (np.ndarray, Sequence)):
334-
raise ValueError(
335-
"Parameter grid for parameter ({0}) needs to"
336-
" be a list or numpy array, but got ({1})."
337-
" Single values need to be wrapped in a list"
338-
" with one element.".format(name, type(v))
339-
)
340-
341-
if len(v) == 0:
342-
raise ValueError(
343-
"Parameter values for parameter ({0}) need "
344-
"to be a non-empty sequence.".format(name)
345-
)
346-
347-
348339
def _check_refit(search_cv, attr):
349340
if not search_cv.refit:
350341
raise AttributeError(
@@ -1385,7 +1376,6 @@ def __init__(
13851376
return_train_score=return_train_score,
13861377
)
13871378
self.param_grid = param_grid
1388-
_check_param_grid(param_grid)
13891379

13901380
def _run_search(self, evaluate_candidates):
13911381
"""Search all candidates in param_grid"""

sklearn/model_selection/_search_successive_halving.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from numbers import Integral
55

66
import numpy as np
7-
from ._search import _check_param_grid
87
from ._search import BaseSearchCV
98
from . import ParameterGrid, ParameterSampler
109
from ..base import is_classifier
@@ -714,7 +713,6 @@ def __init__(
714713
aggressive_elimination=aggressive_elimination,
715714
)
716715
self.param_grid = param_grid
717-
_check_param_grid(self.param_grid)
718716

719717
def _generate_candidate_params(self):
720718
return ParameterGrid(self.param_grid)

sklearn/model_selection/tests/test_search.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,13 @@ def assert_grid_iter_equals_getitem(grid):
133133
@pytest.mark.parametrize(
134134
"input, error_type, error_message",
135135
[
136-
(0, TypeError, r"Parameter .* is not a dict or a list \(0\)"),
136+
(0, TypeError, r"Parameter .* a dict or a list, got: 0 of type int"),
137137
([{"foo": [0]}, 0], TypeError, r"Parameter .* is not a dict \(0\)"),
138138
(
139139
{"foo": 0},
140140
TypeError,
141-
"Parameter.* value is not iterable .*" r"\(key='foo', value=0\)",
141+
r"Parameter (grid|distribution) for parameter 'foo' (is not|needs to be) "
142+
r"(a list or a numpy array|iterable or a distribution).*",
142143
),
143144
],
144145
)
@@ -440,40 +441,43 @@ def test_grid_search_when_param_grid_includes_range():
440441

441442

442443
def test_grid_search_bad_param_grid():
444+
X, y = make_classification(n_samples=10, n_features=5, random_state=0)
443445
param_dict = {"C": 1}
444446
clf = SVC(gamma="auto")
445447
error_msg = re.escape(
446-
"Parameter grid for parameter (C) needs to"
447-
" be a list or numpy array, but got (<class 'int'>)."
448-
" Single values need to be wrapped in a list"
449-
" with one element."
448+
"Parameter grid for parameter 'C' needs to be a list or "
449+
"a numpy array, but got 1 (of type int) instead. Single "
450+
"values need to be wrapped in a list with one element."
450451
)
451-
with pytest.raises(ValueError, match=error_msg):
452-
GridSearchCV(clf, param_dict)
452+
search = GridSearchCV(clf, param_dict)
453+
with pytest.raises(TypeError, match=error_msg):
454+
search.fit(X, y)
453455

454456
param_dict = {"C": []}
455457
clf = SVC()
456458
error_msg = re.escape(
457-
"Parameter values for parameter (C) need to be a non-empty sequence."
459+
"Parameter grid for parameter 'C' need to be a non-empty sequence, got: []"
458460
)
461+
search = GridSearchCV(clf, param_dict)
459462
with pytest.raises(ValueError, match=error_msg):
460-
GridSearchCV(clf, param_dict)
463+
search.fit(X, y)
461464

462465
param_dict = {"C": "1,2,3"}
463466
clf = SVC(gamma="auto")
464467
error_msg = re.escape(
465-
"Parameter grid for parameter (C) needs to"
466-
" be a list or numpy array, but got (<class 'str'>)."
467-
" Single values need to be wrapped in a list"
468-
" with one element."
468+
"Parameter grid for parameter 'C' needs to be a list or a numpy array, "
469+
"but got '1,2,3' (of type str) instead. Single values need to be "
470+
"wrapped in a list with one element."
469471
)
470-
with pytest.raises(ValueError, match=error_msg):
471-
GridSearchCV(clf, param_dict)
472+
search = GridSearchCV(clf, param_dict)
473+
with pytest.raises(TypeError, match=error_msg):
474+
search.fit(X, y)
472475

473476
param_dict = {"C": np.ones((3, 2))}
474477
clf = SVC()
478+
search = GridSearchCV(clf, param_dict)
475479
with pytest.raises(ValueError):
476-
GridSearchCV(clf, param_dict)
480+
search.fit(X, y)
477481

478482

479483
def test_grid_search_sparse():

sklearn/tests/test_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,6 @@ def test_transformers_get_feature_names_out(transformer):
425425
VALIDATE_ESTIMATOR_INIT = [
426426
"ColumnTransformer",
427427
"FeatureUnion",
428-
"GridSearchCV",
429-
"HalvingGridSearchCV",
430428
"SGDOneClassSVM",
431429
"TheilSenRegressor",
432430
"TweedieRegressor",

0 commit comments

Comments
 (0)
0