8000 FIX allow object dtype arrays in clustering metrics (#15535) · rasbt/scikit-learn@d4e0826 · GitHub
[go: up one dir, main page]

Skip to content

Commit d4e0826

Browse files
amuelleradrinjalali
authored andcommitted
FIX allow object dtype arrays in clustering metrics (scikit-learn#15535)
* allow object dtype arrays in clustering metrics * pep8
1 parent 57b3029 commit d4e0826

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

sklearn/metrics/cluster/_supervised.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ def check_clusterings(labels_true, labels_pred):
4343
The predicted labels.
4444
"""
4545
labels_true = check_array(
46-
labels_true, ensure_2d=False, ensure_min_samples=0
46+
labels_true, ensure_2d=False, ensure_min_samples=0, dtype=None,
4747
)
4848
labels_pred = check_array(
49-
labels_pred, ensure_2d=False, ensure_min_samples=0
49+
labels_pred, ensure_2d=False, ensure_min_samples=0, dtype=None,
5050
)
5151

5252
# input checks

sklearn/metrics/cluster/tests/test_common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ def generate_formats(y):
161161
y = np.array(y)
162162
yield y, 'array of ints'
163163
yield y.tolist(), 'list of ints'
164-
yield [str(x) for x in y.tolist()], 'list of strs'
164+
yield [str(x) + "-a" for x in y.tolist()], 'list of strs'
165+
yield (np.array([str(x) + "-a" for x in y.tolist()], dtype=object),
166+
'array of strs')
165167
yield y - 1, 'including negative ints'
166168
yield y + 1, 'strictly positive ints'
167169

0 commit comments

Comments
 (0)
0