From 85041214a2598be4c86f98ab0235002a0c77fcbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Tue, 5 Jan 2021 13:14:02 +0100 Subject: [PATCH 1/5] Add prediction_strength_score function --- sklearn/metrics/__init__.py | 2 + sklearn/metrics/cluster/__init__.py | 4 +- sklearn/metrics/cluster/_unsupervised.py | 68 +++++++++- .../cluster/tests/test_unsupervised.py | 123 ++++++++++++++++++ 4 files changed, 195 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 84e7c98e29324..d61492fe7a297 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -53,6 +53,7 @@ from .cluster import calinski_harabasz_score from .cluster import v_measure_score from .cluster import davies_bouldin_score +from .cluster import prediction_strength_score from .pairwise import euclidean_distances from .pairwise import nan_euclidean_distances @@ -156,6 +157,7 @@ 'precision_recall_curve', 'precision_recall_fscore_support', 'precision_score', + 'prediction_strength_score', 'r2_score', 'rand_score', 'recall_score', diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py index 9e116b40e31da..f9b99ac5c6c02 100644 --- a/sklearn/metrics/cluster/__init__.py +++ b/sklearn/metrics/cluster/__init__.py @@ -23,6 +23,7 @@ from ._unsupervised import silhouette_score from ._unsupervised import calinski_harabasz_score from ._unsupervised import davies_bouldin_score +from ._unsupervised import prediction_strength_score from ._bicluster import consensus_score __all__ = ["adjusted_mutual_info_score", "normalized_mutual_info_score", @@ -32,4 +33,5 @@ "homogeneity_score", "mutual_info_score", "v_measure_score", "fowlkes_mallows_score", "entropy", "silhouette_samples", "silhouette_score", "calinski_harabasz_score", - "davies_bouldin_score", "consensus_score"] + "davies_bouldin_score", "consensus_score", + "prediction_strength_score"] diff --git a/sklearn/metrics/cluster/_unsupervised.py b/sklearn/metrics/cluster/_unsupervised.py index c597277a55b31..e6b2ea8511d31 100644 --- a/sklearn/metrics/cluster/_unsupervised.py +++ b/sklearn/metrics/cluster/_unsupervised.py @@ -5,11 +5,14 @@ # Thierry Guillemot # License: BSD 3 clause - +from itertools import chain +from itertools import permutations import functools import numpy as np +from ...utils import check_array +from ...utils import check_consistent_length from ...utils import check_random_state from ...utils import check_X_y from ...utils import _safe_indexing @@ -361,3 +364,66 @@ def davies_bouldin_score(X, labels): combined_intra_dists = intra_dists[:, None] + intra_dists scores = np.max(combined_intra_dists / centroid_distances, axis=1) return np.mean(scores) + + +def prediction_strength_score(labels_train, labels_test): + """Compute the prediction strength score. + + For each test cluster, we compute the proportion of observation pairs + in that cluster that are also assigned to the same cluster by the + training set centroids. The prediction strength is the minimum of this + quantity over the k test clusters. + + The best value is 1.0 (if the assignments of `labels_train` and + `labels_test` are identical) and the worst value is 0 (if all samples of + one cluster of `labels_test` are not co-members of some cluster in + `labels_train`). + + Parameters + ---------- + labels_train : array-like, shape (``n_test_samples``,) + Predicted labels for each sample in the the test data + based on clusters derived from independent training data. 
+ + labels_test : array-like, shape (``n_test_samples``,) + Predicted labels for each sample in the test data + based on clusters derived from the same data. + + Returns + ------- + score : float + The resulting prediction strength score. + + References + ---------- + .. [1] `Robert Tibshirani and Guenther Walther (2005). "Cluster Validation + by Prediction Strength". Journal of Computational and Graphical Statistics, + 14(3), 511-528. _` + """ + check_consistent_length(labels_train, labels_test) + + labels_train = check_array(labels_train, dtype=np.int32, ensure_2d=False) + labels_test = check_array(labels_test, dtype=np.int32, ensure_2d=False) + + clusters = set(chain(labels_train, labels_test)) + n_clusters = len(clusters) + if n_clusters == 1: + return 1.0 # by definition + + strength = 1.0 + for k in clusters: + # samples assigned to k-th cluster based on test data + samples_test_k = np.flatnonzero(labels_test == k) + cluster_test_size = samples_test_k.shape[0] + + if cluster_test_size < 2: + continue + + matches = 0 + for i, j in permutations(range(cluster_test_size), 2): + if labels_train[samples_test_k[j]] == labels_train[samples_test_k[i]]: + matches += 1 + + strength = min(strength, matches / (cluster_test_size * (cluster_test_size - 1.))) + + return strength diff --git a/sklearn/metrics/cluster/tests/test_unsupervised.py b/sklearn/metrics/cluster/tests/test_unsupervised.py index 354b6c94a7548..da4c6cd9dc62d 100644 --- a/sklearn/metrics/cluster/tests/test_unsupervised.py +++ b/sklearn/metrics/cluster/tests/test_unsupervised.py @@ -10,6 +10,7 @@ from sklearn.metrics import pairwise_distances from sklearn.metrics.cluster import calinski_harabasz_score from sklearn.metrics.cluster import davies_bouldin_score +from sklearn.metrics.cluster import prediction_strength_score def test_silhouette(): @@ -250,3 +251,125 @@ def test_davies_bouldin_score(): X = ([[0, 0], [2, 2], [3, 3], [5, 5]]) labels = [0, 0, 1, 2] pytest.approx(davies_bouldin_score(X, labels), (5. / 4) / 3) + + +def test_prediction_strength_score(): + with pytest.raises(ValueError, match=r"Found array with 0 sample\(s\)"): + prediction_strength_score([], []) + + with pytest.raises(ValueError, + match="Found input variables with inconsistent numbers " + "of samples"): + prediction_strength_score([1], []) + with pytest.raises(ValueError, + match="Found input variables with inconsistent numbers " + "of samples"): + prediction_strength_score([], [1]) + with pytest.raises(ValueError, + match="Found input variables with inconsistent numbers " + "of samples"): + prediction_strength_score([1, 1], [1, 2, 3]) + + assert 1. == prediction_strength_score([1], [1]) + assert 1. == prediction_strength_score([1], [2]) + assert 1. == prediction_strength_score([2], [1]) + assert 1. == prediction_strength_score([1, 1, 1], [1, 1, 1]) + assert 1. == prediction_strength_score([1, 1, 1], [2, 2, 2]) + assert 1. == prediction_strength_score([2, 2, 2], [1, 1, 1]) + + assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2], + [0, 0, 1, 1, 2, 2]) + assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2], + [2, 2, 1, 1, 0, 0]) + assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2], + [0, 0, 2, 2, 1, 1]) + assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2], + [1, 1, 2, 2, 0, 0]) + assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2], + [2, 2, 0, 0, 1, 1]) + assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2], + [1, 1, 0, 0, 2, 2]) + assert 1. 
== prediction_strength_score([3, 3, 6, 6, 9, 9], + [11, 11, 4, 4, 14, 14]) + + # 3 pairs in each cluster, 2 pairs (1-3 and 2-3) are assigned + # different clusters + assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2], + [1, 1, 2, 2, 2, 1]) + assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2], + [2, 2, 1, 1, 1, 2]) + assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2], + [2, 2, 3, 3, 3, 2]) + + # 3 pairs in each cluster, 2 pairs (1-2 and 1-3) are assigned + # different clusters + assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2], + [2, 1, 1, 1, 2, 2]) + assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2], + [1, 2, 2, 2, 1, 1]) + assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2], + [3, 2, 2, 2, 3, 3]) + + # 6 pairs in each cluster, 3 pairs (1-4, 2-4, and 3-4) are assigned + # different clusters + assert .5 == prediction_strength_score([1, 1, 1, 1, 2, 2, 2, 2], + [1, 1, 1, 2, 2, 2, 2, 1]) + assert .5 == prediction_strength_score([1, 1, 1, 1, 2, 2, 2, 2], + [2, 2, 2, 1, 1, 1, 1, 2]) + assert .5 == prediction_strength_score([1, 1, 1, 1, 2, 2, 2, 2], + [2, 2, 2, 3, 3, 3, 3, 2]) + + # 1 pair in each cluster, all clusters are completely different + assert .0 == prediction_strength_score([1, 1, 2, 2], [1, 2, 1, 2]) + assert .0 == prediction_strength_score([1, 1, 2, 2], [2, 1, 2, 1]) + + # 3 pairs in each clusters, all clusters are completely different + assert .0 == prediction_strength_score([1, 2, 3, 1, 2, 3], + [1, 1, 1, 2, 2, 2]) + assert .0 == prediction_strength_score([1, 2, 3, 1, 2, 3], + [2, 2, 2, 1, 1, 1]) + assert .0 == prediction_strength_score([1, 2, 3, 1, 2, 3], + [3, 3, 3, 9, 9, 9]) + + # 1 pair in each cluster, clusters 1 and 3 are completely different + assert .0 == prediction_strength_score([1, 1, 2, 2, 3, 3], + [1, 3, 2, 2, 1, 3]) + + # different number of clusters, 3 pairs and 1 cluster, + # 2 pairs (1-3, 2-3) are assigned different clusters + assert 1. / 3. == prediction_strength_score([1, 1, 2], [1, 1, 1]) + + # different number of clusters, 2 pairs in each cluster + # all pairs are assigned the same cluster + assert 1. == prediction_strength_score([1, 1, 1, 1], [1, 1, 2, 2]) + + # different number of clusters, all clusters are completely different + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [1, 3, 2, 2, 1, 3]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [2, 3, 1, 1, 2, 3]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [2, 1, 3, 3, 2, 1]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [3, 1, 2, 2, 3, 1]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [3, 2, 1, 1, 3, 1]) + + # different number of clusters, cluster 3 is completely different + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [3, 1, 1, 2, 3, 1]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [1, 3, 3, 2, 1, 3]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [2, 1, 1, 3, 2, 1]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [2, 3, 3, 1, 2, 3]) + assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2], + [3, 2, 2, 1, 3, 2]) + + # different number of clusters, clusters 1 and 2 have each + # 2 different pairs + assert 1. / 3. == prediction_strength_score([3, 1, 1, 2, 3, 3], + [1, 1, 1, 2, 2, 2]) + assert 1. / 3. 
== prediction_strength_score([3, 1, 1, 2, 3, 3], + [2, 2, 2, 1, 1, 1]) From 5ef6168c6d15b95739fac502eea934c34a7672ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Tue, 5 Jan 2021 13:17:04 +0100 Subject: [PATCH 2/5] Use contingency_matrix in prediction_strength_score --- sklearn/metrics/cluster/_unsupervised.py | 29 ++++++++---------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/sklearn/metrics/cluster/_unsupervised.py b/sklearn/metrics/cluster/_unsupervised.py index e6b2ea8511d31..b2ae7f1c2707c 100644 --- a/sklearn/metrics/cluster/_unsupervised.py +++ b/sklearn/metrics/cluster/_unsupervised.py @@ -5,8 +5,6 @@ # Thierry Guillemot # License: BSD 3 clause -from itertools import chain -from itertools import permutations import functools import numpy as np @@ -19,6 +17,7 @@ from ..pairwise import pairwise_distances_chunked from ..pairwise import pairwise_distances from ...preprocessing import LabelEncoder +from ._supervised import contingency_matrix from ...utils.validation import _deprecate_positional_args @@ -405,25 +404,15 @@ def prediction_strength_score(labels_train, labels_test): labels_train = check_array(labels_train, dtype=np.int32, ensure_2d=False) labels_test = check_array(labels_test, dtype=np.int32, ensure_2d=False) - clusters = set(chain(labels_train, labels_test)) - n_clusters = len(clusters) + n_clusters = max(np.unique(labels_train).shape[0], + np.unique(labels_test).shape[0]) if n_clusters == 1: return 1.0 # by definition - strength = 1.0 - for k in clusters: - # samples assigned to k-th cluster based on test data - samples_test_k = np.flatnonzero(labels_test == k) - cluster_test_size = samples_test_k.shape[0] + C = contingency_matrix(labels_train, labels_test) + pairs_matching = (C * (C - 1) / 2).sum(axis=0) + M = C.sum(axis=0) + pairs_total = (M * (M - 1) / 2) + nz = pairs_total.nonzero()[0] - if cluster_test_size < 2: - continue - - matches = 0 - for i, j in permutations(range(cluster_test_size), 2): - if labels_train[samples_test_k[j]] == labels_train[samples_test_k[i]]: - matches += 1 - - strength = min(strength, matches / (cluster_test_size * (cluster_test_size - 1.))) - - return strength + return (pairs_matching[nz] / pairs_total[nz]).min() From ae59b7dececd3933fe31e3f3ff689c0dcf300020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Tue, 5 Jan 2021 13:18:12 +0100 Subject: [PATCH 3/5] Use sparse contingency matrix --- sklearn/metrics/cluster/_unsupervised.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/cluster/_unsupervised.py b/sklearn/metrics/cluster/_unsupervised.py index b2ae7f1c2707c..80ddb25441cb3 100644 --- a/sklearn/metrics/cluster/_unsupervised.py +++ b/sklearn/metrics/cluster/_unsupervised.py @@ -365,6 +365,13 @@ def davies_bouldin_score(X, labels): return np.mean(scores) +def _non_zero_add(sparse_matrix, value): + """Add value to non-zero entries of a sparse matrix""" + M = sparse_matrix.copy() + M.data += value + return M + + def prediction_strength_score(labels_train, labels_test): """Compute the prediction strength score. 
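(Illustration only, not part of the patch.) The contingency-matrix formulation introduced in PATCH 2/5 can be checked against the pairwise definition on a tiny labelling; the sketch below assumes nothing beyond NumPy. Test cluster 0 keeps only one of its three sample pairs together under the training-set labels, so the expected score is 1/3.

    import numpy as np

    labels_train = np.array([0, 0, 1, 1, 1])  # labels from training-set centroids
    labels_test = np.array([0, 0, 0, 1, 1])   # labels from test-set centroids

    # Contingency matrix: C[i, j] counts samples assigned to train cluster i
    # and test cluster j.
    C = np.zeros((2, 2), dtype=int)
    for i, j in zip(labels_train, labels_test):
        C[i, j] += 1

    M = C.sum(axis=0)                                # test cluster sizes: [3, 2]
    pairs_matching = (C * (C - 1) // 2).sum(axis=0)  # co-assigned pairs:  [1, 1]
    pairs_total = M * (M - 1) // 2                   # all pairs:          [3, 1]
    print((pairs_matching / pairs_total).min())      # 0.333..., i.e. 1/3
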
@@ -409,9 +416,10 @@ def prediction_strength_score(labels_train, labels_test): if n_clusters == 1: return 1.0 # by definition - C = contingency_matrix(labels_train, labels_test) - pairs_matching = (C * (C - 1) / 2).sum(axis=0) - M = C.sum(axis=0) + C = contingency_matrix(labels_train, labels_test, sparse=True) + Cp = C.multiply(_non_zero_add(C, -1)) / 2 + pairs_matching = np.asarray(Cp.sum(axis=0)).ravel() + M = np.asarray(C.sum(axis=0)).ravel() pairs_total = (M * (M - 1) / 2) nz = pairs_total.nonzero()[0] From e8875908827a5b04047f9b1041b3e0af0c614a4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Tue, 5 Jan 2021 15:50:42 +0100 Subject: [PATCH 4/5] Add PredictionStrengthGridSearchCV class --- sklearn/model_selection/__init__.py | 2 + sklearn/model_selection/_search.py | 316 +++++++++++++++++-- sklearn/model_selection/_validation.py | 135 +++++--- sklearn/model_selection/tests/test_search.py | 136 ++++++++ 4 files changed, 525 insertions(+), 64 deletions(-) diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py index 897183414b5a6..12c6eabc479bf 100644 --- a/sklearn/model_selection/__init__.py +++ b/sklearn/model_selection/__init__.py @@ -29,6 +29,7 @@ from ._search import RandomizedSearchCV from ._search import ParameterGrid from ._search import ParameterSampler +from ._search import PredictionStrengthGridSearchCV from ._search import fit_grid_point if typing.TYPE_CHECKING: @@ -54,6 +55,7 @@ 'ParameterGrid', 'ParameterSampler', 'PredefinedSplit', + 'PredictionStrengthGridSearchCV', 'RandomizedSearchCV', 'ShuffleSplit', 'StratifiedKFold', diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 213204b50c2a7..7beaac3874d6d 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -792,32 +792,10 @@ def evaluate_candidates(candidate_params, cv=None, " totalling {2} fits".format( n_splits, n_candidates, n_candidates * n_splits)) - out = parallel(delayed(_fit_and_score)(clone(base_estimator), - X, y, - train=train, test=test, - parameters=parameters, - split_progress=( - split_idx, - n_splits), - candidate_progress=( - cand_idx, - n_candidates), - **fit_and_score_kwargs) - for (cand_idx, parameters), - (split_idx, (train, test)) in product( - enumerate(candidate_params), - enumerate(cv.split(X, y, groups)))) - - if len(out) < 1: - raise ValueError('No fits were performed. ' - 'Was the CV iterator empty? ' - 'Were there no candidates?') - elif len(out) != n_candidates * n_splits: - raise ValueError('cv.split and cv.get_n_splits returned ' - 'inconsistent results. Expected {} ' - 'splits, got {}' - .format(n_splits, - len(out) // n_candidates)) + out = self._fit_parallel( + parallel, base_estimator, X, y, groups, cv, n_splits, + candidate_params, fit_and_score_kwargs + ) # For callable self.scoring, the return type is only know after # calling. 
If the return type is a dictionary, the error scores @@ -864,8 +842,8 @@ def evaluate_candidates(candidate_params, cv=None, self.best_index_ >= len(results["params"])): raise IndexError('best_index_ index out of range') else: - self.best_index_ = results["rank_test_%s" - % refit_metric].argmin() + self.best_index_ = self._get_best_parameters( + results, refit_metric) self.best_score_ = results["mean_test_%s" % refit_metric][ self.best_index_] self.best_params_ = results["params"][self.best_index_] @@ -891,6 +869,42 @@ def evaluate_candidates(candidate_params, cv=None, return self + def _get_best_parameters(self, results, refit_metric): + return results["rank_test_%s" % refit_metric].argmin() + + def _fit_parallel(self, parallel, base_estimator, X, y, groups, + cv, n_splits, candidate_params, fit_and_score_kwargs): + n_candidates = len(candidate_params) + + out = parallel(delayed(_fit_and_score)(clone(base_estimator), + X, y, + train=train, test=test, + parameters=parameters, + split_progress=( + split_idx, + n_splits), + candidate_progress=( + cand_idx, + n_candidates), + **fit_and_score_kwargs) + for (cand_idx, parameters), + (split_idx, (train, test)) in product( + enumerate(candidate_params), + enumerate(cv.split(X, y, groups)))) + + if len(out) < 1: + raise ValueError('No fits were performed. ' + 'Was the CV iterator empty? ' + 'Were there no candidates?') + elif len(out) != n_candidates * n_splits: + raise ValueError('cv.split and cv.get_n_splits returned ' + 'inconsistent results. Expected {} ' + 'splits, got {}' + .format(n_splits, + len(out) // n_candidates)) + + return out + def _format_results(self, candidate_params, n_splits, out, more_results=None): n_candidates = len(candidate_params) @@ -1288,6 +1302,252 @@ def _run_search(self, evaluate_candidates): evaluate_candidates(ParameterGrid(self.param_grid)) +class PredictionStrengthGridSearchCV(GridSearchCV): + """Exhaustive search to determine optimal number of clusters for a + clusterer. + + The optimal number of clusters is determined by cross-validated grid-search + over a parameter grid. Performance is evaluated using + :func:`sklearn.metrics.prediction_strength_score`. The clusters in the + training set are once predicted after training on the same training set + (1), and once after training on an independent test set (2). A stable + clustering has the property that observation pairs that are assigned to the + same cluster in (1) are also assigned to the same cluster in (2). + The optimal number of clusters k is the largest k such that the + corresponding prediction strength (averaged across all cross-validation + splits) is above some threshold. + + In contrast to :class:`sklearn.model_selection.GridSearchCV`, the splits + obtained from a cross-validation generator are used interchangeably for + training and for testing, and performance is solely evaluated based on + :func:`sklearn.metrics.prediction_strength_score`. + + Parameters + ---------- + estimator : estimator object. + This is assumed to implement the scikit-learn estimator interface. + + param_grid : dict or list of dictionaries + Dictionary with parameters names (string) as keys and lists of + parameter settings to try as values, or a list of such + dictionaries, in which case the grids spanned by each dictionary + in the list are explored. Must contain `n_clusters`. + + threshold : float, default=0.8 + The optimal number of clusters k is the largest k such that the + corresponding prediction strength is above the given threshold. 
+ The threshold must be greater 0 and less or equal 1. + + fit_params : dict, optional + Parameters to pass to the fit method. + + n_jobs : int, default=1 + Number of jobs to run in parallel. + + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + explosion of memory consumption when more jobs get dispatched + than CPUs can process. This parameter can be: + + - None, in which case all the jobs are immediately + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' + + iid : boolean, default=True + If True, the data is assumed to be identically distributed across + the folds, and the loss minimized is the total loss per sample, + and not the mean loss across the folds. + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + Refer :ref:`User Guide ` for the various + cross-validation strategies that can be used here. + + refit : boolean, default=True + Refit the best estimator with the entire dataset. + If "False", it is impossible to make predictions using + this GridSearchCV instance after fitting. + + verbose : integer + Controls the verbosity: the higher, the more messages. + + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + + Attributes + ---------- + cv_results_ : dict of numpy (masked) ndarrays + A dict with keys as column headers and values as columns, that can be + imported into a pandas ``DataFrame``. See :class:`GridSearchCV`: for + details. + + best_estimator_ : estimator + Estimator that was chosen by the search, i.e. estimator + which gave highest score (or smallest loss if specified) + on the left out data. Not available if refit=False. + + best_score_ : float + Score of best_estimator on the left out data. + + best_params_ : dict + Parameter setting that gave the best results on the hold out data. + + best_index_ : int + The index (of the ``cv_results_`` arrays) which corresponds to the best + candidate parameter setting. + + The dict at ``search.cv_results_['params'][search.best_index_]`` gives + the parameter setting for the best model, that gives the highest + mean score (``search.best_score_``). + + n_splits_ : int + The number of cross-validation splits (folds/iterations). + + See Also + --------- + :func:`sklearn.metrics.prediction_strength_score`: + Prediction strength metric that is used during grid search. + + References + ---------- + .. [1] `Robert Tibshirani and Guenther Walther (2005). "Cluster Validation + by Prediction Strength". Journal of Computational and Graphical Statistics, + 14(3), 511-528. 
_` + """ + @_deprecate_positional_args + def __init__(self, estimator, param_grid, *, threshold=0.8, + n_jobs=None, refit=True, cv=None, verbose=0, + pre_dispatch='2*n_jobs', error_score=np.nan): + super().__init__( + estimator=estimator, param_grid=param_grid, + scoring=None, + n_jobs=n_jobs, refit=refit, cv=cv, verbose=verbose, + pre_dispatch=pre_dispatch, error_score=error_score, + return_train_score=False) + self.threshold = threshold + + @_deprecate_positional_args + def fit(self, X, *, groups=None, **fit_params): + """Run fit with all sets of parameters. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + Training vector, where n_samples is the number of samples and + n_features is the number of features. + + groups : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + **fit_params : dict of string -> object + Parameters passed to the ``fit`` method of the estimator + """ + if not np.isfinite(self.threshold): + raise ValueError("threshold must be finite") + if self.threshold <= 0 or self.threshold > 1: + raise ValueError("threshold must be in the interval (0, 1]," + " but was %r" % self.threshold) + if "n_clusters" not in self.param_grid: + raise ValueError("param_grid must contain n_clusters") + + return super().fit(X, y=None, groups=groups, **fit_params) + + def _fit_parallel(self, parallel, base_estimator, X, y, groups, + cv, n_splits, candidate_params, fit_and_score_kwargs): + n_candidates = len(candidate_params) + + # (i) fit X_test, predict X_train + # (ii) fit X_train, predict X_train + # (iii) call prediction_strength_score + out_a = parallel(delayed(_fit_and_score)(clone(base_estimator), + X, y, + train=train, test=test, + parameters=parameters, + split_progress=( + split_idx, + n_splits), + candidate_progress=( + cand_idx, + n_candidates), + use_prediction_strength=True, + **fit_and_score_kwargs) + for (cand_idx, parameters), + (split_idx, (train, test)) in product( + enumerate(candidate_params), + enumerate(cv.split(X, y, groups)))) + + # swap train and test + # (i) fit X_train, predict X_test + # (ii) fit X_test, predict X_test + # (iii) call prediction_strength_score + out_b = parallel(delayed(_fit_and_score)(clone(base_estimator), + X, y, + train=test, test=train, + parameters=parameters, + split_progress=( + split_idx, + n_splits), + candidate_progress=( + cand_idx, + n_candidates), + use_prediction_strength=True, + **fit_and_score_kwargs) + for (cand_idx, parameters), + (split_idx, (train, test)) in product( + enumerate(candidate_params), + enumerate(cv.split(X, y, groups)))) + + def _combine(ret1, ret2): + # average prediction strength score + score1, score2 = ret1["test_scores"], ret2["test_scores"] + ret = {"test_scores": 0.5 * (score1 + score2)} + # sum remaining values (n_samples, fit_time, score_time) + for key in ret1.keys(): + if key != "test_scores": + ret[key] = ret1[key] + ret2[key] + return ret + + out = [_combine(*v) for v in zip(out_a, out_b)] + return out + + def _get_best_parameters(self, results, refit_metric): + best_index = None + test_scores = results["mean_test_%s" % refit_metric] + order = np.argsort(results["param_n_clusters"]) + for i in order: + score = test_scores[i] + if score >= self.threshold: + best_index = i + + if best_index is None: + warnings.warn( + ("No parameter exceeds threshold %f, try decreasing it. " + "Falling back to parameters with highest score." 
+ % self.threshold), + stacklevel=3) + best_index = super()._get_best_parameters(results, refit_metric) + + return best_index + + class RandomizedSearchCV(BaseSearchCV): """Randomized search on hyper parameters. diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 63f68a8e30738..bc605e61d6c6f 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -28,6 +28,7 @@ from ..utils.fixes import delayed from ..utils.metaestimators import _safe_split from ..metrics import check_scoring +from ..metrics import prediction_strength_score from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer from ..exceptions import FitFailedWarning, NotFittedError from ._split import check_cv @@ -451,7 +452,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, return_parameters=False, return_n_test_samples=False, return_times=False, return_estimator=False, split_progress=None, candidate_progress=None, - error_score=np.nan): + error_score=np.nan, use_prediction_strength=False): """Fit estimator and compute scores for a given dataset split. @@ -580,19 +581,68 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, estimator = estimator.set_params(**cloned_parameters) - start_time = time.time() - X_train, y_train = _safe_split(estimator, X, y, train) X_test, y_test = _safe_split(estimator, X, y, test, train) - result = {} + if use_prediction_strength: + (fit_failed, test_scores, + score_time, fit_time) = _prediction_strength_fit_and_score( + estimator, X_train, X_test, y_train, y_test, + fit_params, error_score + ) + else: + (fit_failed, test_scores, train_scores, + score_time, fit_time) = _default_fit_and_score( + estimator, X_train, X_test, y_train, y_test, scorer, + fit_params, return_train_score, error_score + ) + result = {"fit_failed": fit_failed, "test_scores": test_scores} + + if verbose > 1: + total_time = score_time + fit_time + end_msg = f"[CV{progress_msg}] END " + result_msg = params_msg + (";" if params_msg else "") + if verbose > 2 and isinstance(test_scores, dict): + for scorer_name in sorted(test_scores): + result_msg += f" {scorer_name}: (" + if return_train_score: + scorer_scores = train_scores[scorer_name] + result_msg += f"train={scorer_scores:.3f}, " + result_msg += f"test={test_scores[scorer_name]:.3f})" + result_msg += f" total time={logger.short_format_time(total_time)}" + + # Right align the result_msg + end_msg += "." 
* (80 - len(end_msg) - len(result_msg)) + end_msg += result_msg + print(end_msg) + + if return_train_score: + result["train_scores"] = train_scores + if return_n_test_samples: + result["n_test_samples"] = _num_samples(X_test) + if return_times: + result["fit_time"] = fit_time + result["score_time"] = score_time + if return_parameters: + result["parameters"] = parameters + if return_estimator: + result["estimator"] = estimator + return result + + +def _default_fit_and_score(estimator, X_train, X_test, y_train, y_test, + scorer, fit_params, return_train_score, + error_score): + start_time = time.time() + train_scores = None + try: if y_train is None: estimator.fit(X_train, **fit_params) else: estimator.fit(X_train, y_train, **fit_params) - except Exception as e: + except Exception: # Note fit time as time until error fit_time = time.time() - start_time score_time = 0.0 @@ -612,9 +662,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, "Details: \n%s" % (error_score, format_exc()), FitFailedWarning) - result["fit_failed"] = True + fit_failed = True else: - result["fit_failed"] = False + fit_failed = False fit_time = time.time() - start_time test_scores = _score(estimator, X_test, y_test, scorer, error_score) @@ -624,37 +674,50 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, estimator, X_train, y_train, scorer, error_score ) - if verbose > 1: - total_time = score_time + fit_time - end_msg = f"[CV{progress_msg}] END " - result_msg = params_msg + (";" if params_msg else "") - if verbose > 2 and isinstance(test_scores, dict): - for scorer_name in sorted(test_scores): - result_msg += f" {scorer_name}: (" - if return_train_score: - scorer_scores = train_scores[scorer_name] - result_msg += f"train={scorer_scores:.3f}, " - result_msg += f"test={test_scores[scorer_name]:.3f})" - result_msg += f" total time={logger.short_format_time(total_time)}" + return fit_failed, test_scores, train_scores, score_time, fit_time - # Right align the result_msg - end_msg += "." * (80 - len(end_msg) - len(result_msg)) - end_msg += result_msg - print(end_msg) - result["test_scores"] = test_scores - if return_train_score: - result["train_scores"] = train_scores - if return_n_test_samples: - result["n_test_samples"] = _num_samples(X_test) - if return_times: - result["fit_time"] = fit_time - result["score_time"] = score_time - if return_parameters: - result["parameters"] = parameters - if return_estimator: - result["estimator"] = estimator - return result +def _prediction_strength_fit_and_score(estimator, X_train, X_test, + y_train, y_test, fit_params, + error_score): + start_time = time.time() + + try: + if y_train is None: + estimator.fit(X_train, **fit_params) + predict_train = estimator.predict(X_train) + + estimator.fit(X_test, **fit_params) + predict_test = estimator.predict(X_train) + else: + estimator.fit(X_train, y_train, **fit_params) + predict_train = estimator.predict(X_train, y_train) + + estimator.fit(X_test, y_test, **fit_params) + predict_test = estimator.predict(X_train) + + except Exception: + # Note fit time as time until error + fit_time = time.time() - start_time + score_time = 0.0 + if error_score == 'raise': + raise + elif isinstance(error_score, numbers.Number): + test_scores = error_score + warnings.warn("Estimator fit failed. The score on this train-test" + " partition for these parameters will be set to %f. 
" + "Details: \n%s" % + (error_score, format_exc()), + FitFailedWarning) + fit_failed = True + else: + fit_failed = False + + fit_time = time.time() - start_time + test_scores = prediction_strength_score(predict_test, predict_train) + score_time = time.time() - start_time - fit_time + + return fit_failed, test_scores, score_time, fit_time def _score(estimator, X_test, y_test, scorer, error_score="raise"): diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index af2ca92aee26b..77cba4946deee 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -51,6 +51,7 @@ from sklearn.model_selection import RandomizedSearchCV from sklearn.model_selection import ParameterGrid from sklearn.model_selection import ParameterSampler +from sklearn.model_selection import PredictionStrengthGridSearchCV from sklearn.model_selection._search import BaseSearchCV from sklearn.model_selection._validation import FitFailedWarning @@ -2117,3 +2118,138 @@ def test_search_cv_using_minimal_compatible_estimator(SearchCV, Predictor): else: assert_allclose(y_pred, y.mean()) assert search.score(X, y) == pytest.approx(r2_score(y, y_pred)) + + +def test_prediction_strength_fit_easy(): + X, y = make_blobs(600, 10, centers=4, random_state=1) + + pscv = PredictionStrengthGridSearchCV( + estimator=KMeans(random_state=1), + param_grid={'n_clusters': [2, 3, 4, 5, 6, 7, 8]}, + threshold=0.9, + cv=2) + pscv.fit(X) + + assert pscv.best_params_ == {'n_clusters': 4} + assert pscv.best_estimator_.get_params()['n_clusters'] == 4 + assert pscv.best_score_ > 0.9 + assert pscv.n_splits_ == 2 + + pscv_rev = PredictionStrengthGridSearchCV( + estimator=KMeans(random_state=1), + param_grid={'n_clusters': [8, 7, 6, 5, 4, 3, 2]}, + threshold=0.9, + cv=2) + pscv_rev.fit(X) + assert pscv_rev.best_params_ == pscv.best_params_ + assert_almost_equal(pscv_rev.best_score_, pscv.best_score_) + + +def test_prediction_strength_fit_hard(): + X, y = make_blobs(600, 10, centers=3, cluster_std=10., + center_box=(-5., 5.), random_state=1) + + pscv = PredictionStrengthGridSearchCV( + estimator=KMeans(random_state=1), + param_grid={'n_clusters': [2, 3, 4, 5, 6, 7, 8]}, + threshold=0.7, + cv=5) + + with pytest.warns(UserWarning): + pscv.fit(X) + assert pscv.best_params_ == {'n_clusters': 2} + assert pscv.best_estimator_.get_params()['n_clusters'] == 2 + assert 0.6 > pscv.best_score_ + + +def test_prediction_strength_fit_verbose(): + X, y = make_blobs() + pscv = PredictionStrengthGridSearchCV( + estimator=KMeans(random_state=1), + param_grid={'n_clusters': [2, 3, 4, 5, 6, 7, 8]}, + verbose=10, error_score='raise') + + pscv.fit(X) + + +class FailingClusterer(BaseEstimator): + """Classifier that raises a ValueError on fit()""" + + FAILING_PARAMETER = 2 + + def __init__(self, n_clusters=None): + self.n_clusters = n_clusters + + def fit(self, X): + if self.n_clusters == FailingClusterer.FAILING_PARAMETER: + raise ValueError("Failing clusterer failed as required") + + def predict(self, X): + return np.zeros(X.shape[0]) + + def score(self, X): + return np.nan + + +def test_prediction_strength_fit_failing_clusterer(): + X, y = make_blobs() + pscv = PredictionStrengthGridSearchCV( + estimator=FailingClusterer(5), + param_grid={'n_clusters': [2, 3, 4, 5, 6, 7, 8]}, + refit=False, error_score='raise') + + # FailingClusterer issues a ValueError so this is what we look for. 
+ with pytest.raises(ValueError, + match="Failing clusterer failed as required"): + pscv.fit(X) + + +def test_prediction_strength_threshold_range(): + X, y = make_blobs() + pscv = PredictionStrengthGridSearchCV( + estimator=KMeans(random_state=1), + param_grid={'n_clusters': [2, 3, 4, 5, 6, 7, 8]}, + threshold=0.0, + cv=5) + with pytest.raises(ValueError, + match=r"threshold must be in the interval \(0, 1\], " + r"but was"): + pscv.fit(X) + + pscv.set_params(threshold=-0.1) + with pytest.raises(ValueError, + match=r"threshold must be in the interval \(0, 1\], " + r"but was"): + pscv.fit(X) + + pscv.set_params(threshold=10) + with pytest.raises(ValueError, + match=r"threshold must be in the interval \(0, 1\], " + r"but was"): + pscv.fit(X) + + pscv.set_params(threshold=float('nan')) + with pytest.raises(ValueError, + match="threshold must be finite"): + pscv.fit(X) + + pscv.set_params(threshold=float('-inf')) + with pytest.raises(ValueError, + match="threshold must be finite"): + pscv.fit(X) + + pscv.set_params(threshold=float('inf')) + with pytest.raises(ValueError, + match="threshold must be finite"): + pscv.fit(X) + + +def test_prediction_strength_param_grid(): + X, y = make_blobs() + pscv = PredictionStrengthGridSearchCV( + estimator=KMeans(random_state=1), + param_grid={'n_init': [2, 5, 10]}, + cv=5) + with pytest.raises(ValueError, + match="param_grid must contain n_clusters"): + pscv.fit(X) From c81d564d869fe3105b25f2cecd0e760dd46597b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Tue, 5 Jan 2021 15:54:16 +0100 Subject: [PATCH 5/5] Add example showcasing PredictionStrengthGridSearchCV --- .../cluster/plot_prediction_strength_cv.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 examples/cluster/plot_prediction_strength_cv.py diff --git a/examples/cluster/plot_prediction_strength_cv.py b/examples/cluster/plot_prediction_strength_cv.py new file mode 100644 index 0000000000000..256ba0419425b --- /dev/null +++ b/examples/cluster/plot_prediction_strength_cv.py @@ -0,0 +1,92 @@ +""" +===================================================================== +Selecting the number of clusters with prediction strength grid search +===================================================================== + +Prediction strength is a metric that measures the stability of a clustering +algorithm and can be used to determine an optimal number of clusters without +knowing the true cluster assignments. + +First, one splits the data into two parts (A) and (B). +One obtains two cluster assignments, the first one using the centroids +derived from the subset (A), and the second one using the centroids +from the subset (B). Prediction strength measures the proportion of observation +pairs that are assigned to the same clusters according to both clusterings. +The overall prediction strength is the minimum of this quantity over all +predicted clusters. + +By varying the desired number of clusters from low to high, we can choose the +highest number of clusters for which the prediction strength exceeds some +threshold. This is precisely how +:class:`sklearn.model_selection.PredictionStrengthGridSearchCV` operates, +as illustrated in the example below. We evaluate ``n_clusters`` in the range +2 to 8 via 5-fold cross-validation. While the average prediction strength +is high for 2, 3, and 4, it sharply drops below the threshold of 0.8 if +``n_clusters`` is 5 or higher. Therefore, we can conclude that the optimal +number of clusters is 4. 
+""" + +import matplotlib.pyplot as plt +import numpy as np +from scipy.stats import sem + +from sklearn.datasets import make_blobs +from sklearn.cluster import KMeans +from sklearn.model_selection import PredictionStrengthGridSearchCV +from sklearn.model_selection import KFold + +# Generating the sample data from make_blobs +# This particular setting has one distinct cluster and 3 clusters placed close +# together. +X, y = make_blobs(n_samples=500, + n_features=2, + centers=4, + cluster_std=1, + center_box=(-10.0, 10.0), + shuffle=True, + random_state=1) # For reproducibility + +# Define list of values for n_clusters we want to explore +range_n_clusters = [2, 3, 4, 5, 6, 7, 8] +param_grid = {'n_clusters': range_n_clusters} + +# Determine optimal choice of n_clusters using 5-fold cross-validation. +# The optimal number of clusters k is the largest k such that the +# corresponding prediction strength is above some threshold. +# Tibshirani and Guenther suggest a threshold in the range 0.8 to 0.9 +# for well separated clusters. +clusterer = KMeans(random_state=10) +n_splits = 5 +grid_search = PredictionStrengthGridSearchCV(clusterer, threshold=0.8, + param_grid=param_grid, + cv=KFold(n_splits)) +grid_search.fit(X) + +# Retrieve the best configuration +print(grid_search.best_params_, grid_search.best_score_) + +# Retrieve the results stored in the cv_results_ attribute +n_parameters = len(range_n_clusters) +param_n_clusters = grid_search.cv_results_["param_n_clusters"] +mean_test_score = grid_search.cv_results_["mean_test_score"] + +# plot average prediction strength for each value for n_clusters +points = np.empty((n_parameters, 2), dtype=np.float_) +for i, values in enumerate(zip(param_n_clusters, mean_test_score)): + points[i, :] = values +plt.plot(points[:, 0], points[:, 1], marker='o', markerfacecolor='none') +plt.xlabel("n_clusters") +plt.ylabel("average prediction strength") + +# plot the standard error of the prediction strength as error bars +test_score_keys = ["split%d_test_score" % split_i + for split_i in range(n_splits)] +test_scores = [grid_search.cv_results_[key] for key in test_score_keys] +se = np.fromiter((sem(values) for values in zip(*test_scores)), + dtype=np.float_) +plt.errorbar(points[:, 0], points[:, 1], se) + +plt.hlines(grid_search.threshold, min(range_n_clusters), max(range_n_clusters), + linestyles='dashed') + +plt.show()