ENH GridSearchCV and cross_val_score check whether the returned score… · amueller/scikit-learn@2d9cb81 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2d9cb81

Browse files
committed
ENH GridSearchCV and cross_val_score check whether the returned score is actually a number, not an array (otherwise cross_val_score returns bogus results).
1 parent e9556eb commit 2d9cb81

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

sklearn/cross_validation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,9 @@ def _cross_val_score(estimator, X, y, scorer, train, test, verbose,
10671067
score = estimator.score(X_test, y_test)
10681068
else:
10691069
score = scorer(estimator, X_test, y_test)
1070+
if not isinstance(score, numbers.Number):
1071+
raise ValueError("scoring must return a number, got %s (%s)"
1072+
" instead." % (str(score), type(score)))
10701073
if verbose > 1:
10711074
print("score: %f" % score)
10721075
return score

sklearn/grid_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from itertools import product
1111
import time
1212
import warnings
13+
import numbers
1314

1415
import numpy as np
1516

@@ -123,6 +124,10 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer,
123124
else:
124125
this_score = clf.score(X_test)
125126

127+
if not isinstance(this_score, numbers.Number):
128+
raise ValueError("scoring must return a number, got %s (%s)"
129+
" instead." % (str(this_score), type(this_score)))
130+
126131
if verbose > 2:
127132
msg += ", score=%f" % this_score
128133
if verbose > 1:

sklearn/metrics/tests/test_score_objects.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from sklearn.linear_model import Ridge, LogisticRegression
1010
from sklearn.tree import DecisionTreeClassifier
1111
from sklearn.datasets import make_blobs, load_diabetes
12-
from sklearn.cross_validation import train_test_split
12+
from sklearn.cross_validation import train_test_split, cross_val_score
13+
from sklearn.grid_search import GridSearchCV
1314

1415

1516
def test_classification_scores():
@@ -74,3 +75,15 @@ def test_unsupervised_scores():
7475
score1 = scorers['ari'](km, X_test, y_test)
7576
score2 = adjusted_rand_score(y_test, km.predict(X_test))
7677
assert_almost_equal(score1, score2)
78+
79+
80+
def test_raises_on_score_list():
81+
# test that when a list of scores is returned, we raise proper errors.
82+
X, y = make_blobs(random_state=0)
83+
f1_scorer_no_average = AsScorer(f1_score, average=None)
84+
clf = DecisionTreeClassifier()
85+
assert_raises(ValueError, cross_val_score, clf, X, y,
86+
scoring=f1_scorer_no_average)
87+
grid_search = GridSearchCV(clf, scoring=f1_scorer_no_average,
88+
param_grid={'max_depth': [1, 2]})
89+
assert_raises(ValueError, grid_search.fit, X, y)

0 commit comments

Comments
 (0)
0