-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG+2] Raise ValueError for metrics.cluster.supervised with too many classes #5445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,7 +44,7 @@ def check_clusterings(labels_true, labels_pred): | |
return labels_true, labels_pred | ||
|
||
|
||
def contingency_matrix(labels_true, labels_pred, eps=None): | ||
def contingency_matrix(labels_true, labels_pred, eps=None, max_n_classes=5000): | ||
"""Build a contingency matrix describing the relationship between labels. | ||
|
||
Parameters | ||
|
@@ -60,6 +60,11 @@ def contingency_matrix(labels_true, labels_pred, eps=None): | |
matrix. This helps to stop NaN propagation. | ||
If ``None``, nothing is adjusted. | ||
|
||
max_n_classes : int, optional (default=5000) | ||
Maximal number of classes handled by contingency_matrix. | ||
This helps to avoid a MemoryError with a regression target | ||
for mutual_information. | ||
|
||
Returns | ||
------- | ||
contingency: array, shape=[n_classes_true, n_classes_pred] | ||
|
|
@@ -72,6 +77,14 @@ def contingency_matrix(labels_true, labels_pred, eps=None): | |
clusters, cluster_idx = np.unique(labels_pred, return_inverse=True) | ||
n_classes = classes.shape[0] | ||
n_clusters = clusters.shape[0] | ||
if n_classes > max_n_classes: | ||
raise ValueError("Too many classes for a clustering metric. If you " | ||
"want to increase the limit, pass parameter " | ||
"max_n_classes to the scoring function") | ||
if n_clusters > max_n_classes: | ||
raise ValueError("Too many clusters for a clustering metric. If you " | ||
"want to increase the limit, pass parameter " | ||
"max_n_classes to the scoring function") | ||
# Using coo_matrix to accelerate simple histogram calculation, | ||
# i.e. bins are consecutive integers | ||
# Currently, coo_matrix is faster than histogram2d for simple cases | ||
|
@@ -87,7 +100,7 @@ def contingency_matrix(labels_true, labels_pred, eps=None): | |
|
||
# clustering measures | ||
|
||
def adjusted_rand_score(labels_true, labels_pred): | ||
def adjusted_rand_score(labels_true, labels_pred, max_n_classes=5000): | ||
"""Rand index adjusted for chance | ||
|
||
The Rand Index computes a similarity measure between two clusterings | ||
|
@@ -119,6 +132,11 @@ def adjusted_rand_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
Cluster labels to evaluate | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the adjusted_rand_score | ||
metric. Setting it too high can lead to MemoryError or OS | ||
freeze | ||
|
||
Returns | ||
------- | ||
ari : float | ||
|
@@ -180,7 +198,8 @@ def adjusted_rand_score(labels_true, labels_pred): | |
or classes.shape[0] == clusters.shape[0] == len(labels_true)): | ||
return 1.0 | ||
|
||
contingency = contingency_matrix(labels_true, labels_pred) | ||
contingency = contingency_matrix(labels_true, labels_pred, | ||
max_n_classes=max_n_classes) | ||
|
||
# Compute the ARI using the contingency data | ||
sum_comb_c = sum(comb2(n_c) for n_c in contingency.sum(axis=1)) | ||
|
@@ -192,7 +211,8 @@ def adjusted_rand_score(labels_true, labels_pred): | |
return ((sum_comb - prod_comb) / (mean_comb - prod_comb)) | ||
|
||
|
||
def homogeneity_completeness_v_measure(labels_true, labels_pred): | ||
def homogeneity_completeness_v_measure(labels_true, labels_pred, | ||
max_n_classes=5000): | ||
"""Compute the homogeneity and completeness and V-Measure scores at once | ||
|
||
Those metrics are based on normalized conditional entropy measures of | ||
|
@@ -226,6 +246,11 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
cluster labels to evaluate | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the | ||
homogeneity_completeness_v_measure metric. Setting it too | ||
high can lead to MemoryError or OS freeze | ||
|
||
Returns | ||
------- | ||
homogeneity: float | ||
|
@@ -251,7 +276,8 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred): | |
entropy_C = entropy(labels_true) | ||
entropy_K = entropy(labels_pred) | ||
|
||
MI = mutual_info_score(labels_true, labels_pred) | ||
MI = mutual_info_score(labels_true, labels_pred, | ||
max_n_classes=max_n_classes) | ||
|
||
homogeneity = MI / (entropy_C) if entropy_C else 1.0 | ||
completeness = MI / (entropy_K) if entropy_K else 1.0 | ||
|
@@ -265,7 +291,7 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred): | |
return homogeneity, completeness, v_measure_score | ||
|
||
|
||
def homogeneity_score(labels_true, labels_pred): | ||
def homogeneity_score(labels_true, labels_pred, max_n_classes=5000): | ||
"""Homogeneity metric of a cluster labeling given a ground truth | ||
|
||
A clustering result satisfies homogeneity if all of its clusters | ||
|
@@ -289,6 +315,11 @@ def homogeneity_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
cluster labels to evaluate | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the homogeneity_score | ||
metric. Setting it too high can lead to MemoryError or OS | ||
freeze | ||
|
||
Returns | ||
------- | ||
homogeneity: float | ||
|
@@ -336,10 +367,11 @@ def homogeneity_score(labels_true, labels_pred): | |
0.0... | ||
|
||
""" | ||
return homogeneity_completeness_v_measure(labels_true, labels_pred)[0] | ||
return homogeneity_completeness_v_measure(labels_true, labels_pred, | ||
max_n_classes)[0] | ||
|
||
|
||
def completeness_score(labels_true, labels_pred): | ||
def completeness_score(labels_true, labels_pred, max_n_classes=5000): | ||
"""Completeness metric of a cluster labeling given a ground truth | ||
|
||
A clustering result satisfies completeness if all the data points | ||
|
@@ -363,6 +395,11 @@ def completeness_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
cluster labels to evaluate | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the completeness_score | ||
metric. Setting it too high can lead to MemoryError or OS | ||
freeze | ||
|
||
Returns | ||
------- | ||
completeness: float | ||
|
@@ -406,10 +443,11 @@ def completeness_score(labels_true, labels_pred): | |
0.0 | ||
|
||
""" | ||
return homogeneity_completeness_v_measure(labels_true, labels_pred)[1] | ||
return homogeneity_completeness_v_measure(labels_true, labels_pred, | ||
max_n_classes)[1] | ||
|
||
|
||
def v_measure_score(labels_true, labels_pred): | ||
def v_measure_score(labels_true, labels_pred, max_n_classes=5000): | ||
"""V-measure cluster labeling given a ground truth. | ||
|
||
This score is identical to :func:`normalized_mutual_info_score`. | ||
|
@@ -437,6 +475,11 @@ def v_measure_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
cluster labels to evaluate | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the v_measure_score | ||
metric. Setting it too high can lead to MemoryError or OS | ||
freeze | ||
|
||
Returns | ||
------- | ||
v_measure: float | ||
|
@@ -501,10 +544,12 @@ def v_measure_score(labels_true, labels_pred): | |
0.0... | ||
|
||
""" | ||
return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] | ||
return homogeneity_completeness_v_measure(labels_true, labels_pred, | ||
max_n_classes)[2] | ||
|
||
|
||
def mutual_info_score(labels_true, labels_pred, contingency=None): | ||
def mutual_info_score(labels_true, labels_pred, contingency=None, | ||
max_n_classes=5000): | ||
"""Mutual Information between two clusterings | ||
|
||
The Mutual Information is a measure of the similarity between two labels of | ||
|
@@ -544,6 +589,11 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): | |
If value is ``None``, it will be computed, otherwise the given value is | ||
used, with ``labels_true`` and ``labels_pred`` ignored. | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the mutual_info_score | ||
metric. Setting it too high can lead to MemoryError or OS | ||
freeze | ||
|
||
Returns | ||
------- | ||
mi: float | ||
|
@@ -556,7 +606,8 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): | |
""" | ||
if contingency is None: | ||
labels_true, labels_pred = check_clusterings(labels_true, labels_pred) | ||
contingency = contingency_matrix(labels_true, labels_pred) | ||
contingency = contingency_matrix(labels_true, labels_pred, | ||
max_n_classes=max_n_classes) | ||
contingency = np.array(contingency, dtype='float') | ||
contingency_sum = np.sum(contingency) | ||
pi = np.sum(contingency, axis=1) | ||
|
@@ -575,7 +626,7 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): | |
return mi.sum() | ||
|
||
|
||
def adjusted_mutual_info_score(labels_true, labels_pred): | ||
def adjusted_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): | ||
"""Adjusted Mutual Information between two clusterings | ||
|
||
Adjusted Mutual Information (AMI) is an adjustment of the Mutual | ||
|
@@ -608,6 +659,11 @@ def adjusted_mutual_info_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
A clustering of the data into disjoint subsets. | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the | ||
adjusted_mutual_info_score metric. Setting it too high can | ||
lead to MemoryError or OS freeze | ||
|
||
Returns | ||
------- | ||
ami: float (upper-limited by 1.0) | ||
|
@@ -658,7 +714,8 @@ def adjusted_mutual_info_score(labels_true, labels_pred): | |
if (classes.shape[0] == clusters.shape[0] == 1 | ||
or classes.shape[0] == clusters.shape[0] == 0): | ||
return 1.0 | ||
contingency = contingency_matrix(labels_true, labels_pred) | ||
contingency = contingency_matrix(labels_true, labels_pred, | ||
max_n_classes=max_n_classes) | ||
contingency = np.array(contingency, dtype='float') | ||
# Calculate the MI for the two clusterings | ||
mi = mutual_info_score(labels_true, labels_pred, | ||
|
@@ -671,7 +728,7 @@ def adjusted_mutual_info_score(labels_true, labels_pred): | |
return ami | ||
|
||
|
||
def normalized_mutual_info_score(labels_true, labels_pred): | ||
def normalized_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): | ||
"""Normalized Mutual Information between two clusterings | ||
|
||
Normalized Mutual Information (NMI) is a normalization of the Mutual | ||
|
@@ -701,6 +758,11 @@ def normalized_mutual_info_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
A clustering of the data into disjoint subsets. | ||
|
||
max_n_classes: int, optional (default=5000) | ||
Maximal number of classes handled by the | ||
normalized_mutual_info_score metric. Setting it too high can | ||
lead to MemoryError or OS freeze | ||
|
||
Returns | ||
------- | ||
nmi: float | ||
|
@@ -739,7 +801,8 @@ def normalized_mutual_info_score(labels_true, labels_pred): | |
if (classes.shape[0] == clusters.shape[0] == 1 | ||
or classes.shape[0] == clusters.shape[0] == 0): | ||
return 1.0 | ||
contingency = contingency_matrix(labels_true, labels_pred) | ||
contingency = contingency_matrix(labels_true, labels_pred, | ||
max_n_classes=max_n_classes) | ||
contingency = np.array(contingency, dtype='float') | ||
# Calculate the MI for the two clusterings | ||
mi = mutual_info_score(labels_true, labels_pred, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -181,10 +181,14 @@ def test_exactly_zero_info_score(): | |
for i in np.logspace(1, 4, 4).astype(np.int): | ||
labels_a, labels_b = np.ones(i, dtype=np.int),\ | ||
np.arange(i, dtype=np.int) | ||
assert_equal(normalized_mutual_info_score(labels_a, labels_b), 0.0) | ||
assert_equal(v_measure_score(labels_a, labels_b), 0.0) | ||
assert_equal(adjusted_mutual_info_score(labels_a, labels_b), 0.0) | ||
assert_equal(normalized_mutual_info_score(labels_a, labels_b), 0.0) | ||
assert_equal(normalized_mutual_info_score(labels_a, labels_b, | ||
max_n_classes=1e4), 0.0) | ||
assert_equal(v_measure_score(labels_a, labels_b, | ||
max_n_classes=1e4), 0.0) | ||
assert_equal(adjusted_mutual_info_score(labels_a, labels_b, | ||
max_n_classes=1e4), 0.0) | ||
assert_equal(normalized_mutual_info_score(labels_a, labels_b, | ||
max_n_classes=1e4), 0.0) | ||
|
||
|
||
def test_v_measure_and_mutual_information(seed=36): | ||
|
@@ -196,3 +200,26 @@ def test_v_measure_and_mutual_information(seed=36): | |
assert_almost_equal(v_measure_score(labels_a, labels_b), | ||
2.0 * mutual_info_score(labels_a, labels_b) / | ||
(entropy(labels_a) + entropy(labels_b)), 0) | ||
|
||
|
||
def test_max_n_classes(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You need to create a local random number generator: rng = np.random.RandomState(seed=0) and use it after: labels_true = rng.random(53) |
||
rng = np.random.RandomState(seed=0) | ||
labels_true = rng.rand(53) | ||
labels_pred = rng.rand(53) | ||
labels_zero = np.zeros(53) | ||
labels_true[:2] = 0 | ||
labels_zero[:3] = 1 | ||
labels_pred[:2] = 0 | ||
for score_func in score_funcs: | ||
expected = ("Too many classes for a clustering metric. If you " | ||
"want to increase the limit, pass parameter " | ||
"max_n_classes to the scoring function") | ||
assert_raise_message(ValueError, expected, score_func, | ||
labels_true, labels_pred, | ||
max_n_classes=50) | ||
expected = ("Too many clusters for a clustering metric. If you " | ||
"want to increase the limit, pass parameter " | ||
"max_n_classes to the scoring function") | ||
assert_raise_message(ValueError, expected, score_func, | ||
labels_zero, labels_pred, | ||
max_n_classes=50) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You seem to be above column 79 here. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a 79 line I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my mistake