scikit-learn/scikit-learn · commit 1fa3ec3

check_scorable returns scorer
1 parent 389ed8d commit 1fa3ec3

File tree: 6 files changed (+86, -62 lines)

sklearn/cross_validation.py (5 additions, 37 deletions)

@@ -27,7 +27,7 @@
 from .utils.fixes import unique
 from .externals.joblib import Parallel, delayed
 from .externals.six import string_types, with_metaclass
-from .metrics.scorer import _deprecate_loss_and_score_funcs
+from .metrics.scorer import check_scorable
 
 __all__ = ['Bootstrap',
            'KFold',
@@ -1087,9 +1087,7 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
     """
     X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
     cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
-    _check_scorable(estimator, score_func=score_func, scoring=scoring)
-    scorer = _deprecate_loss_and_score_funcs(score_func=score_func,
-                                             scoring=scoring)
+    scorer = check_scorable(estimator, score_func=score_func, scoring=scoring)
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
     parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
@@ -1168,20 +1166,12 @@ def _fit(fit_function, X_train, y_train, **fit_params):
 def _score(estimator, X_test, y_test, scorer):
     """Compute the score of an estimator on a given test set."""
     if y_test is None:
-        if scorer is None:
-            score = estimator.score(X_test)
-        else:
-            score = scorer(estimator, X_test)
+        score = scorer(estimator, X_test)
     else:
-        if scorer is None:
-            score = estimator.score(X_test, y_test)
-        else:
-            score = scorer(estimator, X_test, y_test)
-
+        score = scorer(estimator, X_test, y_test)
     if not isinstance(score, numbers.Number):
         raise ValueError("scoring must return a number, got %s (%s) instead."
                          % (str(score), type(score)))
-
     return score
 
 
@@ -1262,24 +1252,6 @@ def _check_cv(cv, X=None, y=None, classifier=False, warn_mask=False):
     return cv
 
 
-def _check_scorable(estimator, scoring=None, loss_func=None, score_func=None):
-    """Check that estimator can be fitted and score can be computed."""
-    if (not hasattr(estimator, 'fit') or
-            not (hasattr(estimator, 'predict')
-                 or hasattr(estimator, 'score'))):
-        raise TypeError("estimator should a be an estimator implementing"
-                        " 'fit' and 'predict' or 'score' methods,"
-                        " %s (type %s) was passed" %
-                        (estimator, type(estimator)))
-    if (scoring is None and loss_func is None and score_func
-            is None):
-        if not hasattr(estimator, 'score'):
-            raise TypeError(
-                "If no scoring is specified, the estimator passed "
-                "should have a 'score' method. The estimator %s "
-                "does not." % estimator)
-
-
 def permutation_test_score(estimator, X, y, score_func=None, cv=None,
                            n_permutations=100, n_jobs=1, labels=None,
                            random_state=0, verbose=0, scoring=None):
@@ -1351,11 +1323,7 @@ def permutation_test_score(estimator, X, y, score_func=None, cv=None,
     """
     X, y = check_arrays(X, y, sparse_format='csr')
     cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
-    scorer = _deprecate_loss_and_score_funcs(
-        loss_func=None,
-        score_func=score_func,
-        scoring=scoring
-    )
+    scorer = check_scorable(estimator, scoring=scoring, score_func=score_func)
    random_state = check_random_state(random_state)
 
     # We clone the estimator to make sure that all the folds are
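
The trimmed-down _score above relies on check_scorable always handing back a callable, falling back to a passthrough wrapper around estimator.score when no scoring is given. A minimal, self-contained sketch of that contract (the names PassthroughScorer, score_on_test_set and ConstantEstimator are illustrative only, not the library code):

import numbers

class PassthroughScorer(object):
    """Illustrative stand-in for the _passthrough_scorer added in scorer.py."""
    def __call__(self, estimator, *args, **kwargs):
        return estimator.score(*args, **kwargs)

def score_on_test_set(estimator, X_test, y_test, scorer):
    """Mirrors the simplified _score: a single call path, no None checks."""
    if y_test is None:
        score = scorer(estimator, X_test)
    else:
        score = scorer(estimator, X_test, y_test)
    if not isinstance(score, numbers.Number):
        raise ValueError("scoring must return a number, got %r" % (score,))
    return score

class ConstantEstimator(object):
    """Toy estimator whose score method returns a fixed number."""
    def fit(self, X, y=None):
        return self
    def score(self, X, y=None):
        return 0.5

est = ConstantEstimator().fit([[0.0]], [0])
print(score_on_test_set(est, [[0.0]], [0], PassthroughScorer()))  # prints 0.5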

sklearn/feature_selection/rfe.py (4 additions, 9 deletions)

@@ -13,9 +13,9 @@
 from ..base import clone
 from ..base import is_classifier
 from ..cross_validation import _check_cv as check_cv
-from ..cross_validation import _check_scorable, _split, _score
+from ..cross_validation import _split, _score
 from .base import SelectorMixin
-from ..metrics.scorer import _deprecate_loss_and_score_funcs
+from ..metrics.scorer import check_scorable
 
 
 class RFE(BaseEstimator, MetaEstimatorMixin, SelectorMixin):
@@ -326,8 +326,8 @@ def fit(self, X, y):
                   verbose=self.verbose - 1)
 
         cv = check_cv(self.cv, X, y, is_classifier(self.estimator))
-        _check_scorable(self.estimator, scoring=self.scoring,
-                        loss_func=self.loss_func)
+        scorer = check_scorable(self.estimator, scoring=self.scoring,
+                                loss_func=self.loss_func)
         scores = np.zeros(X.shape[1])
 
         # Cross-validation
@@ -345,11 +345,6 @@ def fit(self, X, y):
 
             estimator = clone(self.estimator)
             estimator.fit(X_train_subset, y_train)
-
-            scorer = _deprecate_loss_and_score_funcs(
-                loss_func=self.loss_func,
-                scoring=self.scoring
-            )
             score = _score(estimator, X_test_subset, y_test, scorer)
 
             if self.verbose > 0:

sklearn/grid_search.py (5 additions, 7 deletions)

@@ -23,12 +23,12 @@
 from .base import BaseEstimator, is_classifier, clone
 from .base import MetaEstimatorMixin
 from .cross_validation import _check_cv as check_cv
-from .cross_validation import _check_scorable, _cross_val_score
+from .cross_validation import _cross_val_score
 from .externals.joblib import Parallel, delayed, logger
 from .externals import six
 from .utils import safe_mask, check_random_state
 from .utils.validation import _num_samples, check_arrays
-from .metrics.scorer import _deprecate_loss_and_score_funcs
+from .metrics.scorer import check_scorable
 
 
 __all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point',
@@ -308,8 +308,6 @@ def __init__(self, estimator, scoring=None, loss_func=None,
         self.cv = cv
         self.verbose = verbose
         self.pre_dispatch = pre_dispatch
-        _check_scorable(self.estimator, scoring=self.scoring,
-                        loss_func=self.loss_func, score_func=self.score_func)
 
     def score(self, X, y=None):
         """Returns the score on the given test data and labels, if the search
@@ -360,13 +358,13 @@ def _fit(self, X, y, parameter_iterable):
 
         estimator = self.estimator
         cv = self.cv
+        self.scorer_ = check_scorable(self.estimator, scoring=self.scoring,
+                                      loss_func=self.loss_func,
+                                      score_func=self.score_func)
 
         n_samples = _num_samples(X)
         X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr')
 
-        self.scorer_ = _deprecate_loss_and_score_funcs(
-            self.loss_func, self.score_func, self.scoring)
-
         if y is not None:
             if len(y) != n_samples:
                 raise ValueError('Target variable (y) has a different number '

sklearn/learning_curve.py (3 additions, 4 deletions)

@@ -11,7 +11,8 @@
 from .utils import check_arrays
 from .externals.joblib import Parallel, delayed
 from .metrics.scorer import get_scorer
-from .cross_validation import _check_scorable, _split, _fit, _score
+from .cross_validation import _split, _fit, _score
+from .metrics.scorer import check_scorable
 
 
 def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 10),
@@ -101,6 +102,7 @@ def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 10),
     X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
     # Make a list since we will be iterating multiple times over the folds
     cv = list(_check_cv(cv, X, y, classifier=is_classifier(estimator)))
+    scorer = check_scorable(estimator, scoring=scoring)
 
     # HACK as long as boolean indices are allowed in cv generators
     if cv[0][0].dtype == bool:
@@ -119,9 +121,6 @@ def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 10),
     if verbose > 0:
         print("[learning_curve] Training set sizes: " + str(train_sizes_abs))
 
-    _check_scorable(estimator, scoring=scoring)
-    scorer = get_scorer(scoring)
-
     parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
                         verbose=verbose)
     if exploit_incremental_learning:

sklearn/metrics/scorer.py (63 additions, 0 deletions)

@@ -198,6 +198,69 @@ def get_scorer(scoring):
     return scorer
 
 
+class _passthrough_scorer(object):
+    """Callable that wraps estimator.score"""
+    def __call__(self, estimator, *args, **kwargs):
+        return estimator.score(*args, **kwargs)
+
+
+def check_scorable(estimator, scoring=None, loss_func=None, score_func=None):
+    """Check if estimator can be scored.
+
+    A TypeError will be thrown if the estimator cannot be scored.
+
+    Parameters
+    ----------
+    estimator : estimator object implementing 'fit'
+        The object to use to fit the data.
+
+    scoring : string, callable or None, optional, default: None
+        A string (see model evaluation documentation) or
+        a scorer callable object / function with signature
+        ``scorer(estimator, X, y)``.
+
+    loss_func : callable or None, optional, default: None
+        A loss function callable object / function with signature
+        ``loss_func(estimator, X, y)``.
+
+    score_func : callable or None, optional, default: None
+        A scoring function with signature
+        ``score_func(estimator, X, y)``.
+
+    Returns
+    -------
+    scoring : callable
+        A scorer callable object / function with signature
+        ``scorer(estimator, X, y)``.
+    """
+    if not hasattr(estimator, 'fit'):
+        raise TypeError("estimator should be an estimator implementing "
+                        "'fit' method, %s (type %s) was passed" %
+                        (estimator, type(estimator)))
+
+    if scoring is None and loss_func is None and score_func is None:
+        if hasattr(estimator, 'score'):
+            return _passthrough_scorer()
+        else:
+            raise TypeError(
+                "If no scoring is specified, the estimator passed should "
+                "have a 'score' method. The estimator %s (type %s) "
+                "does not." % (estimator, type(estimator)))
+    else:
+        if hasattr(estimator, 'predict'):
+            scorer = _deprecate_loss_and_score_funcs(scoring=scoring,
+                loss_func=loss_func, score_func=score_func)
+            if scorer is None:
+                raise ValueError("no scoring")
+            else:
+                return scorer
+        else:
+            raise TypeError(
+                "If a scoring is specified, the estimator passed should "
+                "have a 'predict' method. The estimator %s (type %s) "
+                "does not." % (estimator, type(estimator)))
+
+
 def make_scorer(score_func, greater_is_better=True, needs_proba=False,
                 needs_threshold=False, **kwargs):
     """Make a scorer from a performance metric or loss function.

sklearn/tests/test_grid_search.py (6 additions, 5 deletions)

@@ -190,8 +190,9 @@ def test_grid_search_no_score():
     assert_equal(grid_search.score(X, y), grid_search_no_score.score(X, y))
 
     # giving no scoring function raises an error
-    assert_raise_message(TypeError, "no scoring",
-                         GridSearchCV, clf_no_score, {'C': Cs})
+    grid_search_no_score = GridSearchCV(clf_no_score, {'C': Cs})
+    assert_raise_message(TypeError, "no scoring", grid_search_no_score.fit,
+                         [[1]])
 
 
 def test_trivial_grid_scores():
@@ -494,9 +495,9 @@ def test_bad_estimator():
     # test grid-search with clustering algorithm which doesn't support
     # "predict"
     sc = SpectralClustering()
-    assert_raises(TypeError, GridSearchCV, sc,
-                  param_grid=dict(gamma=[.1, 1, 10]),
-                  scoring='ari')
+    grid_search = GridSearchCV(sc, param_grid=dict(gamma=[.1, 1, 10]),
+                               scoring='ari')
+    assert_raises(TypeError, grid_search.fit, [[1]])
 
 
 def test_param_sampler():
