8000 BUG: restore score functionality in grid_search · seckcoder/scikit-learn@021a4e0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 021a4e0

Browse files
committed
BUG: restore score functionality in grid_search
1 parent 34334f5 commit 021a4e0

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

sklearn/grid_search.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,6 @@ def fit(self, X, y=None, **params):
370370
self.predict = best_estimator.predict
371371
if hasattr(best_estimator, 'predict_proba'):
372372
self.predict_proba = best_estimator.predict_proba
373-
if hasattr(best_estimator, 'score'):
374-
self.score_ = best_estimator.score
375373

376374
# Store the computed scores
377375
# XXX: the name is too specific, it shouldn't have
@@ -383,8 +381,12 @@ def fit(self, X, y=None, **params):
383381
return self
384382

385383
def score(self, X, y=None):
386-
# This method is overridden during the fit if the best estimator
387-
# found has a score function.
384+
if hasattr(self.best_estimator, 'score'):
385+
return self.best_estimator.score(X, y)
386+
if self.score_func is None:
387+
raise ValueError("No score function explicitly defined, "
388+
"and the estimator doesn't provide one %s"
389+
% self.best_estimator)
388390
y_predicted = self.predict(X)
389391
return self.score_func(y, y_predicted)
390392

sklearn/tests/test_grid_search.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,15 @@ def score(self, X=None, Y=None F9CF ):
4242
def test_grid_search():
4343
"""Test that the best estimator contains the right value for foo_param"""
4444
clf = MockClassifier()
45-
cross_validation = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
45+
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]})
4646
# make sure it selects the smallest parameter in case of ties
47-
assert_equal(cross_validation.fit(X, y).best_estimator_.foo_param, 2)
47+
grid_search.fit(X, y)
48+
assert_equal(grid_search.best_estimator_.foo_param, 2)
4849

4950
for i, foo_i in enumerate([1, 2, 3]):
50-
assert cross_validation.grid_scores_[i][0] == {'foo_param': foo_i}
51+
assert grid_search.grid_scores_[i][0] == {'foo_param': foo_i}
52+
# Smoke test the score:
53+
grid_search.score(X, y)
5154

5255

5356
def test_grid_search_error():
@@ -101,6 +104,9 @@ def test_grid_search_sparse_score_func():
101104

102105
assert_array_equal(y_pred, y_pred2)
103106
assert_equal(C, C2)
107+
# Smoke test the score
108+
#np.testing.assert_allclose(f1_score(cv.predict(X_[:180]), y[:180]),
109+
# cv.score(X_[:180], y[:180]))
104110

105111

106112
class BrokenClassifier(BaseEstimator):

0 commit comments

Comments
 (0)
0