TST check for nan and inf + single sample for metrics cl… (#10830) · crankycoder/scikit-learn@a47e914 · GitHub
[go: up one dir, main page]

Skip to content

Commit a47e914

Browse files
glemaitre authored and rth committed
TST check for nan and inf + single sample for metrics cl… (scikit-learn#10830)
1 parent ac72a48 commit a47e914

File tree

3 files changed

+44
-18
lines changed

3 files changed

+44
-18
lines changed

sklearn/metrics/cluster/supervised.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from scipy import sparse as sp
2222

2323
from .expected_mutual_info_fast import expected_mutual_information
24-
from ...utils.validation import check_array
24+
from ...utils.validation import check_array, check_consistent_length
2525
from ...utils.fixes import comb, _astype_copy_false
2626

2727

@@ -36,14 +36,18 @@ def check_clusterings(labels_true, labels_pred):
3636
3737
Parameters
3838
----------
39-
labels_true : int array, shape = [n_samples]
40-
The true labels
39+
labels_true : array-like of shape (n_samples,)
40+
The true labels.
4141
42-
labels_pred : int array, shape = [n_samples]
43-
The predicted labels
42+
labels_pred : array-like of shape (n_samples,)
43+
The predicted labels.
4444
"""
45-
labels_true = np.asarray(labels_true)
46-
labels_pred = np.asarray(labels_pred)
45+
labels_true = check_array(
46+
labels_true, ensure_2d=False, ensure_min_samples=0
47+
)
48+
labels_pred = check_array(
49+
labels_pred, ensure_2d=False, ensure_min_samples=0
50+
)
4751

4852
# input checks
4953
if labels_true.ndim != 1:
@@ -52,10 +56,8 @@ def check_clusterings(labels_true, labels_pred):
5256
if labels_pred.ndim != 1:
5357
raise ValueError(
5458
"labels_pred must be 1D: shape is %r" % (labels_pred.shape,))
55-
if labels_true.shape != labels_pred.shape:
56-
raise ValueError(
57-
"labels_true and labels_pred must have same size, got %d and %d"
58-
% (labels_true.shape[0], labels_pred.shape[0]))
59+
check_consistent_length(labels_true, labels_pred)
60+
5961
return labels_true, labels_pred
6062

6163

sklearn/metrics/cluster/tests/test_common.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ def test_normalized_output(metric_name):
126126
# 0.22 AMI and NMI changes
127127
@pytest.mark.filterwarnings('ignore::FutureWarning')
128128
@pytest.mark.parametrize(
129-
"metric_name",
130-
dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
129+
"metric_name", dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
131130
)
132131
def test_permute_labels(metric_name):
133132
# All clustering metrics do not change score due to permutations of labels
@@ -150,11 +149,10 @@ def test_permute_labels(metric_name):
150149
# 0.22 AMI and NMI changes
151150
@pytest.mark.filterwarnings('ignore::FutureWarning')
152151
@pytest.mark.parametrize(
153-
"metric_name",
154-
dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
152+
"metric_name", dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
155153
)
156154
# For all clustering metrics Input parameters can be both
157-
# in the form of arrays lists, positive, negetive or string
155+
# in the form of arrays lists, positive, negative or string
158156
def test_format_invariance(metric_name):
159157
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
160158
y_pred = [0, 1, 2, 3, 4, 5, 6, 7]
@@ -183,3 +181,29 @@ def generate_formats(y):
183181
y_true_gen = generate_formats(y_true)
184182
for (y_true_fmt, fmt_name) in y_true_gen:
185183
assert score_1 == metric(X, y_true_fmt)
184+
185+
186+
@pytest.mark.parametrize("metric", SUPERVISED_METRICS.values())
187+
def test_single_sample(metric):
188+
# only the supervised metrics support single sample
189+
for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]:
190+
metric([i], [j])
191+
192+
193+
@pytest.mark.parametrize(
194+
"metric_name, metric_func",
195+
dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS).items()
196+
)
197+
def test_inf_nan_input(metric_name, metric_func):
198+
if metric_name in SUPERVISED_METRICS:
199+
invalids = [([0, 1], [np.inf, np.inf]),
200+
([0, 1], [np.nan, np.nan]),
201+
([0, 1], [np.nan, np.inf])]
202+
else:
203+
X = np.random.randint(10, size=(2, 10))
204+
invalids = [(X, [np.inf, np.inf]),
205+
(X, [np.nan, np.nan]),
206+
(X, [np.nan, np.inf])]
207+
with pytest.raises(ValueError, match='contains NaN, infinity'):
208+
for args in invalids:
209+
metric_func(*args)

sklearn/metrics/cluster/tests/test_supervised.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
@ignore_warnings(category=FutureWarning)
3535
def test_error_messages_on_wrong_input():
3636
for score_func in score_funcs:
37-
expected = ('labels_true and labels_pred must have same size,'
38-
' got 2 and 3')
37+
expected = (r'Found input variables with inconsistent numbers '
38+
r'of samples: \[2, 3\]')
3939
with pytest.raises(ValueError, match=expected):
4040
score_func([0, 1], [1, 1, 1])
4141

0 commit comments

Comments
 (0)
0