8000 Updates for GridSearchCV · scikit-learn/scikit-learn@1f581ff · GitHub
[go: up one dir, main page]

Skip to content

Commit 1f581ff

Browse files
committed
Updates for GridSearchCV
1 parent b6abb12 commit 1f581ff

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

sklearn/model_selection/_search.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,9 @@ def _select_best_index(refit, refit_metric, results):
757757
return best_index
758758

759759
def fit(self, X, y=None, *, groups=None, **fit_params):
760+
return self._fit(X, y=y, groups=groups, **fit_params)
761+
762+
def _fit(self, X, y=None, *, groups=None, **fit_params):
760763
"""Run fit with all sets of parameters.
761764
762765
Parameters
@@ -1386,9 +1389,12 @@ def __init__(
13861389
)
13871390
self.param_grid = param_grid
13881391

1392+
def _fit(self, X, y=None, *, groups=None, **fit_params):
1393+
_check_param_grid(self.param_grid)
1394+
return super()._fit(X, y=y, groups=groups, **fit_params)
1395+
13891396
def _run_search(self, evaluate_candidates):
13901397
"""Search all candidates in param_grid"""
1391-
_check_param_grid(self.param_grid)
13921398
evaluate_candidates(ParameterGrid(self.param_grid))
13931399

13941400

sklearn/model_selection/_search_successive_halving.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,6 @@ def __init__(
716716
self.param_grid = param_grid
717717

718718
def _generate_candidate_params(self):
719-
_check_param_grid(self.param_grid)
720719
return ParameterGrid(self.param_grid)
721720

722721

sklearn/model_selection/tests/test_search.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ def test_grid_search_when_param_grid_includes_range():
440440

441441

442442
def test_grid_search_bad_param_grid():
443+
X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
443444
param_dict = {"C": 1}
444445
clf = SVC(gamma="auto")
445446
error_msg = re.escape(
@@ -449,15 +450,15 @@ def test_grid_search_bad_param_grid():
449450
" with one element."
450451
)
451452
with pytest.raises(ValueError, match=error_msg):
452-
GridSearchCV(clf, param_dict)
453+
GridSearchCV(clf, param_dict).fit(X_, y_)
453454

454455
param_dict = {"C": []}
455456
clf = SVC()
456457
error_msg = re.escape(
457458
"Parameter values for parameter (C) need to be a non-empty sequence."
458459
)
459460
with pytest.raises(ValueError, match=error_msg):
460-
GridSearchCV(clf, param_dict)
461+
GridSearchCV(clf, param_dict).fit(X_, y_)
461462

462463
param_dict = {"C": "1,2,3"}
463464
clf = SVC(gamma="auto")
@@ -468,12 +469,12 @@ def test_grid_search_bad_param_grid():
468469
" with one element."
469470
)
470471
with pytest.raises(ValueError, match=error_msg):
471-
GridSearchCV(clf, param_dict)
472+
GridSearchCV(clf, param_dict).fit(X_, y_)
472473

473474
param_dict = {"C": np.ones((3, 2))}
474475
clf = SVC()
475476
with pytest.raises(ValueError):
476-
GridSearchCV(clf, param_dict)
477+
GridSearchCV(clf, param_dict).fit(X_, y_)
477478

478479

479480
def test_grid_search_sparse():

0 commit comments

Comments
 (0)
0