MNT Improve robustness of sparse test in `HDBSCAN` (#26889) · punndcoder28/scikit-learn@640bfd5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 640bfd5

Browse files
Micky774 authored and punndcoder28 committed
MNT Improve robustness of sparse test in HDBSCAN (scikit-learn#26889)
1 parent 05150c5 commit 640bfd5

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

sklearn/cluster/tests/test_hdbscan.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,22 +287,37 @@ def test_hdbscan_precomputed_non_brute(tree):
287287
def test_hdbscan_sparse():
288288
"""
289289
Tests that HDBSCAN works correctly when passing sparse feature data.
290+
Evaluates correctness by comparing against the same data passed as a dense
291+
array.
290292
"""
291-
sparse_X = sparse.csr_matrix(X)
292293

293-
labels = HDBSCAN().fit(sparse_X).labels_
294-
n_clusters = len(set(labels) - OUTLIER_SET)
294+
dense_labels = HDBSCAN().fit(X).labels_
295+
n_clusters = len(set(dense_labels) - OUTLIER_SET)
295296
assert n_clusters == 3
296297

297-
sparse_X_nan = sparse_X.copy()
298-
sparse_X_nan[0, 0] = np.nan
299-
labels = HDBSCAN().fit(sparse_X_nan).labels_
300-
n_clusters = len(set(labels) - OUTLIER_SET)
301-
assert n_clusters == 3
298+
_X_sparse = sparse.csr_matrix(X)
299+
X_sparse = _X_sparse.copy()
300+
sparse_labels = HDBSCAN().fit(X_sparse).labels_
301+
assert_array_equal(dense_labels, sparse_labels)
302+
303+
# Compare that the sparse and dense non-precomputed routines return the same labels
304+
# where the 0th observation contains the outlier.
305+
for outlier_val, outlier_type in ((np.inf, "infinite"), (np.nan, "missing")):
306+
X_dense = X.copy()
307+
X_dense[0, 0] = outlier_val
308+
dense_labels = HDBSCAN().fit(X_dense).labels_
309+
n_clusters = len(set(dense_labels) - OUTLIER_SET)
310+
assert n_clusters == 3
311+
assert dense_labels[0] == _OUTLIER_ENCODING[outlier_type]["label"]
312+
313+
X_sparse = _X_sparse.copy()
314+
X_sparse[0, 0] = outlier_val
315+
sparse_labels = HDBSCAN().fit(X_sparse).labels_
316+
assert_array_equal(dense_labels, sparse_labels)
302317

303318
msg = "Sparse data matrices only support algorithm `brute`."
304319
with pytest.raises(ValueError, match=msg):
305-
HDBSCAN(metric="euclidean", algorithm="balltree").fit(sparse_X)
320+
HDBSCAN(metric="euclidean", algorithm="balltree").fit(X_sparse)
306321

307322

308323
@pytest.mark.parametrize("algorithm", ALGORITHMS)

0 commit comments

Comments
 (0)
0