8000 TST Extend tests for `scipy.sparse.*array` in `sklearn/cluster/tests/test_bisect_k_means.py` by Kislovskiy · Pull Request #27099 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

TST Extend tests for scipy.sparse.*array in sklearn/cluster/tests/test_bisect_k_means.py #27099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions sklearn/cluster/tests/test_bisect_k_means.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np
import pytest
import scipy.sparse as sp

from sklearn.cluster import BisectingKMeans
from sklearn.metrics import v_measure_score
from sklearn.utils._testing import assert_allclose, assert_array_equal
from sklearn.utils.fixes import CSR_CONTAINERS


@pytest.mark.parametrize("bisecting_strategy", ["biggest_inertia", "largest_cluster"])
Expand Down Expand Up @@ -33,7 +33,8 @@ def test_three_clusters(bisecting_strategy, init):
assert_allclose(v_measure_score(expected_labels, bisect_means.labels_), 1.0)


def test_sparse():
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_sparse(csr_container):
"""Test Bisecting K-Means with sparse data.

Checks if labels and centers are the same between dense and sparse.
Expand All @@ -43,7 +44,7 @@ def test_sparse():

X = rng.rand(20, 2)
X[X < 0.8] = 0
X_csr = sp.csr_matrix(X)
X_csr = csr_container(X)

bisect_means = BisectingKMeans(n_clusters=3, random_state=0)

Expand Down Expand Up @@ -84,48 +85,48 @@ def test_one_cluster():
assert_allclose(bisect_means.cluster_centers_, X.mean(axis=0).reshape(1, -1))


@pytest.mark.parametrize("is_sparse", [True, False])
def test_fit_predict(is_sparse):
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS + [None])
def test_fit_predict(csr_container):
"""Check if labels from fit(X) method are same as from fit(X).predict(X)."""
rng = np.random.RandomState(0)

X = rng.rand(10, 2)

if is_sparse:
if csr_container is not None:
X[X < 0.8] = 0
X = sp.csr_matrix(X)
X = csr_container(X)

bisect_means = BisectingKMeans(n_clusters=3, random_state=0)
bisect_means.fit(X)

assert_array_equal(bisect_means.labels_, bisect_means.predict(X))


@pytest.mark.parametrize("is_sparse", [True, False])
def test_dtype_preserved(is_sparse, global_dtype):
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS + [None])
def test_dtype_preserved(csr_container, global_dtype):
"""Check that centers dtype is the same as input data dtype."""
rng = np.random.RandomState(0)
X = rng.rand(10, 2).astype(global_dtype, copy=False)

if is_sparse:
if csr_container is not None:
X[X < 0.8] = 0
X = sp.csr_matrix(X)
X = csr_container(X)

km = BisectingKMeans(n_clusters=3, random_state=0)
km.fit(X)

assert km.cluster_centers_.dtype == global_dtype


@pytest.mark.parametrize("is_sparse", [True, False])
def test_float32_float64_equivalence(is_sparse):
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS + [None])
def test_float32_float64_equivalence(csr_container):
"""Check that the results are the same between float32 and float64."""
rng = np.random.RandomState(0)
X = rng.rand(10, 2)

if is_sparse:
if csr_container is not None:
X[X < 0.8] = 0
X = sp.csr_matrix(X)
X = csr_container(X)

km64 = BisectingKMeans(n_clusters=3, random_state=0).fit(X)
km32 = BisectingKMeans(n_clusters=3, random_state=0).fit(X.astype(np.float32))
Expand Down
4 changes: 2 additions & 2 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,9 +1002,9 @@ def _check_large_sparse(X, accept_large_sparse=False):
"""Raise a ValueError if X has 64bit indices and accept_large_sparse=False"""
if not accept_large_sparse:
supported_indices = ["int32"]
if X.getformat() == "coo":
if X.format == "coo":
index_keys = ["col", "row"]
elif X.getformat() in ["csr", "csc", "bsr"]:
elif X.format in ["csr", "csc", "bsr"]:
index_keys = ["indices", "indptr"]
else:
return
Expand Down
0