8000 [MRG+1] add scorer based on explained_variance_score (#9259) · scikit-learn/scikit-learn@9f91ec7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9f91ec7

Browse files
qinhanmin2014 authored and lesteve committed
[MRG+1] add scorer based on explained_variance_score (#9259)
1 parent 6d4ae1b commit 9f91ec7

File tree

4 files changed

+14
-7
lines changed

4 files changed

+14
-7
lines changed

doc/modules/model_evaluation.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ Scoring Function
8181
'v_measure_score' :func:`metrics.v_measure_score`
8282

8383
**Regression**
84+
'explained_variance' :func:`metrics.explained_variance_score`
8485
'neg_mean_absolute_error' :func:`metrics.mean_absolute_error`
8586
'neg_mean_squared_error' :func:`metrics.mean_squared_error`
8687
'neg_mean_squared_log_error' :func:`metrics.mean_squared_log_error`
@@ -101,7 +102,7 @@ Usage examples:
101102
>>> model = svm.SVC()
102103
>>> cross_val_score(model, X, y, scoring='wrong_choice')
103104
Traceback (most recent call last):
104-
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'completeness_score', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']
105+
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'completeness_score', 'explained_variance', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']
105106

106107
.. note::
107108

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ Model selection and evaluation
139139
:class:`model_selection.RepeatedStratifiedKFold`.
140140
:issue:`8120` by `Neeraj Gangwar`_.
141141

142+
- Added a scorer based on :class:`metrics.explained_variance_score`.
143+
:issue:`9259` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.
144+
142145
Miscellaneous
143146

144147
- Validation that input data contains no NaN or inf can now be suppressed

sklearn/metrics/scorer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from . import (r2_score, median_absolute_error, mean_absolute_error,
2727
mean_squared_error, mean_squared_log_error, accuracy_score,
2828
f1_score, roc_auc_score, average_precision_score,
29-
precision_score, recall_score, log_loss)
29+
precision_score, recall_score, log_loss,
30+
explained_variance_score)
3031

3132
from .cluster import adjusted_rand_score
3233
from .cluster import homogeneity_score
@@ -463,6 +464,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
463464

464465

465466
# Standard regression scores
467+
explained_variance_scorer = make_scorer(explained_variance_score)
466468
r2_scorer = make_scorer(r2_score)
467469
neg_mean_squared_error_scorer = make_scorer(mean_squared_error,
468470
greater_is_better=False)
@@ -525,7 +527,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
525527
fowlkes_mallows_scorer = make_scorer(fowlkes_mallows_score)
526528

527529

528-
SCORERS = dict(r2=r2_scorer,
530+
SCORERS = dict(explained_variance=explained_variance_scorer,
531+
r2=r2_scorer,
529532
neg_median_absolute_error=neg_median_absolute_error_scorer,
530533
neg_mean_absolute_error=neg_mean_absolute_error_scorer,
531534
neg_mean_squared_error=neg_mean_squared_error_scorer,

sklearn/metrics/tests/test_score_objects.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from sklearn.svm import LinearSVC
3030
from sklearn.pipeline import make_pipeline
3131
from sklearn.cluster import KMeans
32-
from sklearn.dummy import DummyRegressor
3332
from sklearn.linear_model import Ridge, LogisticRegression
3433
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
3534
from sklearn.datasets import make_blobs
@@ -42,8 +41,9 @@
4241
from sklearn.externals import joblib
4342

4443

45-
REGRESSION_SCORERS = ['r2', 'neg_mean_absolute_error',
46-
'neg_mean_squared_error', 'neg_mean_squared_log_error',
44+
REGRESSION_SCORERS = ['explained_variance', 'r2',
45+
'neg_mean_absolute_error', 'neg_mean_squared_error',
46+
'neg_mean_squared_log_error',
4747
'neg_median_absolute_error', 'mean_absolute_error',
4848
'mean_squared_error', 'median_absolute_error']
4949

@@ -68,7 +68,7 @@
6868

6969
def _make_estimators(X_train, y_train, y_ml_train):
7070
# Make estimators that make sense to test various scoring methods
71-
sensible_regr = DummyRegressor(strategy='median')
71+
sensible_regr = DecisionTreeRegressor(random_state=0)
7272
sensible_regr.fit(X_train, y_train)
7373
sensible_clf = DecisionTreeClassifier(random_state=0)
7474
sensible_clf.fit(X_train, y_train)

0 commit comments

Comments (0)