8000 BaseSearchCV._run_search raises NotImplementedError instead of being an abstractmethod by adrinjalali · Pull Request #12182 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

BaseSearchCV._run_search raises NotImplementedError instead of being an abstractmethod #12182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,11 +580,10 @@ def classes_(self):
self._check_is_fitted("classes_")
return self.best_estimator_.classes_

@abstractmethod
def _run_search(self, evaluate_candidates):
"""Repeatedly calls `evaluate_candidates` to conduct a search.

This method, implemented in sub-classes, makes it is possible to
This method, implemented in sub-classes, makes it possible to
customize the the scheduling of evaluations: GridSearchCV and
RandomizedSearchCV schedule evaluations for their whole parameter
search space at once but other more sequential approaches are also
Expand Down Expand Up @@ -613,6 +612,7 @@ def _run_search(self, evaluate_candidates):
if score[0] < score[1]:
evaluate_candidates([{'C': 0.1}])
"""
raise NotImplementedError("_run_search not implemented.")

def fit(self, X, y=None, groups=None, **fit_params):
"""Run fit with all sets of parameters.
Expand Down
22 changes: 21 additions & 1 deletion sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def test_parameter_grid():

@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
@pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22

def test_grid_search():
# Test that the best estimator contains the right value for foo_param
clf = MockClassifier()
Expand Down Expand Up @@ -1678,6 +1677,27 @@ def _run_search(self, evaluate):
"Attribute %s not equal" % attr


def test__custom_fit_no_run_search():
class NoRunSearchSearchCV(BaseSearchCV):
def __init__(self, estimator, **kwargs):
super(NoRunSearchSearchCV, self).__init__(estimator, **kwargs)

def fit(self, X, y=None, groups=None, **fit_params):
return self

# this should not raise any exceptions
NoRunSearchSearchCV(SVC(), cv=5).fit(X, y)

class BadSearchCV(BaseSearchCV):
def __init__(self, estimator, **kwargs):
super(BadSearchCV, self).__init__(estimator, **kwargs)

with pytest.raises(NotImplementedError,
match="_run_search not implemented."):
# this should raise a NotImplementedError
BadSearchCV(SVC(), cv=5).fit(X, y)


def test_deprecated_grid_search_iid():
depr_message = ("The default of the `iid` parameter will change from True "
"to False in version 0.22")
Expand Down
0