[MRG + 1] Add fowlkes-mallows and other supervised cluster metrics to SCORERS dict so it can be used in hyper-param search (#8117) · scikit-learn/scikit-learn@2f7f5a1


Commit 2f7f5a1

raghavrv authored and agramfort committed
[MRG + 1] Add fowlkes-mallows and other supervised cluster metrics to SCORERS dict so it can be used in hyper-param search (#8117)
* Add supervised cluster metrics to metrics.scorers
* Add all the supervised cluster metrics to the tests
* Add test for fowlkes_mallows_score in unsupervised grid search
* COSMIT: Clarify comment on CLUSTER_SCORERS
* Fix doctest
1 parent 8695ff5 commit 2f7f5a1
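In effect, each newly registered metric name becomes a valid scoring string anywhere scikit-learn accepts one. A minimal sketch of the resulting usage (illustrative, not part of the commit):

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.model_selection import cross_val_score

    X, y = make_blobs(random_state=0, centers=2)

    # Any of the newly registered names works here, e.g. the adjusted
    # mutual information between predicted and true cluster labels.
    scores = cross_val_score(KMeans(n_clusters=2), X, y,
                             scoring='adjusted_mutual_info_score')
    print(scores)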

File tree: 4 files changed (+53 -12 lines)

doc/modules/model_evaluation.rst (+1 -1)
@@ -94,7 +94,7 @@ Usage examples:
     >>> model = svm.SVC()
     >>> cross_val_score(model, X, y, scoring='wrong_choice')
     Traceback (most recent call last):
-    ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc']
+    ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'completeness_score', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']
 
 .. note::
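The "Valid options" list in this doctest mirrors the keys of the SCORERS registry, so the same names can be inspected directly. A quick sketch (SCORERS is importable from sklearn.metrics, as the tests below also do):

    from sklearn.metrics import SCORERS

    # After this commit, the sorted keys should match the doctest above,
    # including the supervised cluster metric names.
    print(sorted(SCORERS.keys()))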

sklearn/metrics/scorer.py (+27 -1)
@@ -27,7 +27,16 @@
                mean_squared_error, mean_squared_log_error, accuracy_score,
                f1_score, roc_auc_score, average_precision_score,
                precision_score, recall_score, log_loss)
+
 from .cluster import adjusted_rand_score
+from .cluster import homogeneity_score
+from .cluster import completeness_score
+from .cluster import v_measure_score
+from .cluster import mutual_info_score
+from .cluster import adjusted_mutual_info_score
+from .cluster import normalized_mutual_info_score
+from .cluster import fowlkes_mallows_score
+
 from ..utils.multiclass import type_of_target
 from ..externals import six
 from ..base import is_regressor
@@ -393,6 +402,14 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
 
 # Clustering scores
 adjusted_rand_scorer = make_scorer(adjusted_rand_score)
+homogeneity_scorer = make_scorer(homogeneity_score)
+completeness_scorer = make_scorer(completeness_score)
+v_measure_scorer = make_scorer(v_measure_score)
+mutual_info_scorer = make_scorer(mutual_info_score)
+adjusted_mutual_info_scorer = make_scorer(adjusted_mutual_info_score)
+normalized_mutual_info_scorer = make_scorer(normalized_mutual_info_score)
+fowlkes_mallows_scorer = make_scorer(fowlkes_mallows_score)
+
 
 SCORERS = dict(r2=r2_scorer,
                neg_median_absolute_error=neg_median_absolute_error_scorer,
@@ -406,7 +423,16 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
                average_precision=average_precision_scorer,
                log_loss=log_loss_scorer,
                neg_log_loss=neg_log_loss_scorer,
-               adjusted_rand_score=adjusted_rand_scorer)
+               # Cluster metrics that use supervised evaluation
+               adjusted_rand_score=adjusted_rand_scorer,
+               homogeneity_score=homogeneity_scorer,
+               completeness_score=completeness_scorer,
+               v_measure_score=v_measure_scorer,
+               mutual_info_score=mutual_info_scorer,
+               adjusted_mutual_info_score=adjusted_mutual_info_scorer,
+               normalized_mutual_info_score=normalized_mutual_info_scorer,
+               fowlkes_mallows_score=fowlkes_mallows_scorer)
+
 
 for name, metric in [('precision', precision_score),
                      ('recall', recall_score), ('f1', f1_score)]:
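Each of these wrappers is a predict-based scorer: called as scorer(estimator, X, y_true), it computes metric(y_true, estimator.predict(X)), which is exactly the equivalence the updated test below asserts. A small sketch of that behavior (illustrative, not part of the diff):

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.metrics import get_scorer
    from sklearn.metrics.cluster import v_measure_score

    X, y = make_blobs(random_state=0, centers=2)
    km = KMeans(n_clusters=3, random_state=0).fit(X)

    # Both routes score the predicted cluster labels against ground truth:
    score_via_scorer = get_scorer('v_measure_score')(km, X, y)
    score_direct = v_measure_score(y, km.predict(X))
    print(score_via_scorer, score_direct)  # the two values agree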

sklearn/metrics/tests/test_score_objects.py (+19 -10)
@@ -18,7 +18,7 @@
 from sklearn.base import BaseEstimator
 from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
                              log_loss, precision_score, recall_score)
-from sklearn.metrics.cluster import adjusted_rand_score
+from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics.scorer import (check_scoring, _PredictScorer,
                                     _passthrough_scorer)
 from sklearn.metrics import make_scorer, get_scorer, SCORERS
@@ -47,9 +47,17 @@
               'roc_auc', 'average_precision', 'precision',
               'precision_weighted', 'precision_macro', 'precision_micro',
               'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
-              'neg_log_loss', 'log_loss',
-              'adjusted_rand_score'  # not really, but works
-              ]
+              'neg_log_loss', 'log_loss']
+
+# All supervised cluster scorers (They behave like classification metric)
+CLUSTER_SCORERS = ["adjusted_rand_score",
+                   "homogeneity_score",
+                   "completeness_score",
+                   "v_measure_score",
+                   "mutual_info_score",
+                   "adjusted_mutual_info_score",
+                   "normalized_mutual_info_score",
+                   "fowlkes_mallows_score"]
 
 MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples']

@@ -65,6 +73,7 @@ def _make_estimators(X_train, y_train, y_ml_train):
     return dict(
         [(name, sensible_regr) for name in REGRESSION_SCORERS] +
         [(name, sensible_clf) for name in CLF_SCORERS] +
+        [(name, sensible_clf) for name in CLUSTER_SCORERS] +
         [(name, sensible_ml_clf) for name in MULTILABEL_ONLY_SCORERS]
     )

@@ -330,16 +339,16 @@ def test_thresholded_scorers_multilabel_indicator_data():
     assert_almost_equal(score1, score2)
 
 
-def test_unsupervised_scorers():
+def test_supervised_cluster_scorers():
     # Test clustering scorers against gold standard labeling.
-    # We don't have any real unsupervised Scorers yet.
     X, y = make_blobs(random_state=0, centers=2)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     km = KMeans(n_clusters=3)
     km.fit(X_train)
-    score1 = get_scorer('adjusted_rand_score')(km, X_test, y_test)
-    score2 = adjusted_rand_score(y_test, km.predict(X_test))
-    assert_almost_equal(score1, score2)
+    for name in CLUSTER_SCORERS:
+        score1 = get_scorer(name)(km, X_test, y_test)
+        score2 = getattr(cluster_module, name)(y_test, km.predict(X_test))
+        assert_almost_equal(score1, score2)
 
 
 @ignore_warnings
@@ -445,4 +454,4 @@ def test_scoring_is_not_metric():
     assert_raises_regexp(ValueError, 'make_scorer', check_scoring,
                          Ridge(), r2_score)
     assert_raises_regexp(ValueError, 'make_scorer', check_scoring,
-                         KMeans(), adjusted_rand_score)
+                         KMeans(), cluster_module.adjusted_rand_score)

sklearn/model_selection/tests/test_search.py (+6)
@@ -542,6 +542,12 @@ def test_unsupervised_grid_search():
     # ARI can find the right number :)
     assert_equal(grid_search.best_params_["n_clusters"], 3)
 
+    grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]),
+                               scoring='fowlkes_mallows_score')
+    grid_search.fit(X, y)
+    # So can FMS ;)
+    assert_equal(grid_search.best_params_["n_clusters"], 3)
+
     # Now without a score, and without y
     grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]))
     grid_search.fit(X)
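For reference, a self-contained version of the pattern this test exercises (a sketch, assuming the three-center blob data that make_blobs generates by default, as the surrounding test module does):

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.model_selection import GridSearchCV

    # Three true clusters; the supervised cluster scorer should
    # steer the search toward n_clusters=3.
    X, y = make_blobs(random_state=0)

    grid_search = GridSearchCV(KMeans(random_state=0),
                               param_grid=dict(n_clusters=[2, 3, 4]),
                               scoring='fowlkes_mallows_score')
    grid_search.fit(X, y)
    print(grid_search.best_params_)  # expected: {'n_clusters': 3}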
