8000 TST Extend tests for `scipy.sparse.*array` in `sklearn/metrics/tests/… · REDVM/scikit-learn@e3f4299 · GitHub
[go: up one dir, main page]

Skip to content

Commit e3f4299

Browse files
Tialoglemaitre
authored andcommitted
TST Extend tests for scipy.sparse.*array in sklearn/metrics/tests/test_ranking.py (scikit-learn#27212)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent e62170d commit e3f4299

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

sklearn/metrics/tests/test_ranking.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import pytest
66
from scipy import stats
7-
from scipy.sparse import csr_matrix
87

98
from sklearn import datasets, svm
109
from sklearn.datasets import make_multilabel_classification
@@ -36,6 +35,7 @@
3635
assert_array_equal,
3736
)
3837
from sklearn.utils.extmath import softmax
38+
from sklearn.utils.fixes import CSR_CONTAINERS
3939
from sklearn.utils.validation import (
4040
check_array,
4141
check_consistent_length,
@@ -1762,10 +1762,12 @@ def test_label_ranking_loss():
17621762
(0 + 2 / 2 + 1 / 2) / 3.0,
17631763
)
17641764

1765-
# Sparse csr matrices
1765+
1766+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
1767+
def test_label_ranking_loss_sparse(csr_container):
17661768
assert_almost_equal(
17671769
label_ranking_loss(
1768-
csr_matrix(np.array([[0, 1, 0], [1, 1, 0]])), [[0.1, 10, -3], [3, 1, 3]]
1770+
csr_container(np.array([[0, 1, 0], [1, 1, 0]])), [[0.1, 10, -3], [3, 1, 3]]
17691771
),
17701772
(0 + 2 / 2) / 2.0,
17711773
)
@@ -2193,10 +2195,13 @@ def test_top_k_accuracy_score_error(y_true, y_score, labels, msg):
21932195
top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
21942196

21952197

2196-
def test_label_ranking_avg_precision_score_should_allow_csr_matrix_for_y_true_input():
2198+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
2199+
def test_label_ranking_avg_precision_score_should_allow_csr_matrix_for_y_true_input(
2200+
csr_container,
2201+
):
21972202
# Test that label_ranking_avg_precision_score accept sparse y_true.
21982203
# Non-regression test for #22575
2199-
y_true = csr_matrix([[1, 0, 0], [0, 0, 1]])
2204+
y_true = csr_container([[1, 0, 0], [0, 0, 1]])
22002205
y_score = np.array([[0.5, 0.9, 0.6], [0, 0, 1]])
22012206
result = label_ranking_average_precision_score(y_true, y_score)
22022207
assert result == pytest.approx(2 / 3)

0 commit comments

Comments
 (0)
0