8000 BaseSearchCV._run_search raises NotImplementedError instead of being … · sjtrny/scikit-learn@239482f · GitHub
[go: up one dir, main page]

Skip to content

Commit 239482f

Browse files
adrinjalaliamueller
authored andcommitted
BaseSearchCV._run_search raises NotImplementedError instead of being an abstractmethod (scikit-learn#12182)
* _run_search raises NotImplementedError instead of being and abstractmethod * add error message * test for a BaseSearchCV child w/o a _run_search * make the test python2 compatible, still in 0.20 zone. * specify cv in tests not to trigger the related FutureWarning * PEP8
1 parent a358d7d commit 239482f

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

sklearn/model_selection/_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,10 @@ def classes_(self):
580580
self._check_is_fitted("classes_")
581581
return self.best_estimator_.classes_
582582

583-
@abstractmethod
584583
def _run_search(self, evaluate_candidates):
585584
"""Repeatedly calls `evaluate_candidates` to conduct a search.
586585
587-
This method, implemented in sub-classes, makes it is possible to
586+
This method, implemented in sub-classes, makes it possible to
588587
customize the the scheduling of evaluations: GridSearchCV and
589588
RandomizedSearchCV schedule evaluations for their whole parameter
590589
search space at once but other more sequential approaches are also
@@ -613,6 +612,7 @@ def _run_search(self, evaluate_candidates):
613612
if score[0] < score[1]:
614613
evaluate_candidates([{'C': 0.1}])
615614
"""
615+
raise NotImplementedError("_run_search not implemented.")
616616

617617
def fit(self, X, y=None, groups=None, **fit_params):
618618
"""Run fit with all sets of parameters.

sklearn/model_selection/tests/test_search.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ def test_parameter_grid():
182182

183183
@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
184184
@pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22
185-
186185
def test_grid_search():
187186
# Test that the best estimator contains the right value for foo_param
188187
clf = MockClassifier()
@@ -1678,6 +1677,27 @@ def _run_search(self, evaluate):
16781677
"Attribute %s not equal" % attr
16791678

16801679

1680+
def test__custom_fit_no_run_search():
1681+
class NoRunSearchSearchCV(BaseSearchCV):
1682+
def __init__(self, estimator, **kwargs):
1683+
super(NoRunSearchSearchCV, self).__init__(estimator, **kwargs)
1684+
1685+
def fit(self, X, y=None, groups=None, **fit_params):
1686+
return self
1687+
1688+
# this should not raise any exceptions
1689+
NoRunSearchSearchCV(SVC(), cv=5).fit(X, y)
1690+
1691+
class BadSearchCV(BaseSearchCV):
1692+
def __init__(self, estimator, **kwargs):
1693+
super(BadSearchCV, self).__init__(estimator, **kwargs)
1694+
1695+
with pytest.raises(NotImplementedError,
1696+
match="_run_search not implemented."):
1697+
# this should raise a NotImplementedError
1698+
BadSearchCV(SVC(), cv=5).fit(X, y)
1699+
1700+
16811701
def test_deprecated_grid_search_iid():
16821702
depr_message = ("The default of the `iid` parameter will change from True "
16831703
"to False in version 0.22")

0 commit comments

Comments
 (0)
0