|
4 | 4 | import numpy as np
|
5 | 5 | import pytest
|
6 | 6 | from scipy import stats
|
7 |
| -from scipy.sparse import csr_matrix |
8 | 7 |
|
9 | 8 | from sklearn import datasets, svm
|
10 | 9 | from sklearn.datasets import make_multilabel_classification
|
|
36 | 35 | assert_array_equal,
|
37 | 36 | )
|
38 | 37 | from sklearn.utils.extmath import softmax
|
| 38 | +from sklearn.utils.fixes import CSR_CONTAINERS |
39 | 39 | from sklearn.utils.validation import (
|
40 | 40 | check_array,
|
41 | 41 | check_consistent_length,
|
@@ -1762,10 +1762,12 @@ def test_label_ranking_loss():
|
1762 | 1762 | (0 + 2 / 2 + 1 / 2) / 3.0,
|
1763 | 1763 | )
|
1764 | 1764 |
|
1765 |
| - # Sparse csr matrices |
| 1765 | + |
| 1766 | +@pytest.mark.parametrize("csr_container", CSR_CONTAINERS) |
| 1767 | +def test_label_ranking_loss_sparse(csr_container): |
1766 | 1768 | assert_almost_equal(
|
1767 | 1769 | label_ranking_loss(
|
1768 |
| - csr_matrix(np.array([[0, 1, 0], [1, 1, 0]])), [[0.1, 10, -3], [3, 1, 3]] |
| 1770 | + csr_container(np.array([[0, 1, 0], [1, 1, 0]])), [[0.1, 10, -3], [3, 1, 3]] |
1769 | 1771 | ),
|
1770 | 1772 | (0 + 2 / 2) / 2.0,
|
1771 | 1773 | )
|
@@ -2193,10 +2195,13 @@ def test_top_k_accuracy_score_error(y_true, y_score, labels, msg):
|
2193 | 2195 | top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
|
2194 | 2196 |
|
2195 | 2197 |
|
2196 |
| -def test_label_ranking_avg_precision_score_should_allow_csr_matrix_for_y_true_input(): |
| 2198 | +@pytest.mark.parametrize("csr_container", CSR_CONTAINERS) |
| 2199 | +def test_label_ranking_avg_precision_score_should_allow_csr_matrix_for_y_true_input( |
| 2200 | + csr_container, |
| 2201 | +): |
2197 | 2202 | # Test that label_ranking_avg_precision_score accept sparse y_true.
|
2198 | 2203 | # Non-regression test for #22575
|
2199 |
| - y_true = csr_matrix([[1, 0, 0], [0, 0, 1]]) |
| 2204 | + y_true = csr_container([[1, 0, 0], [0, 0, 1]]) |
2200 | 2205 | y_score = np.array([[0.5, 0.9, 0.6], [0, 0, 1]])
|
2201 | 2206 | result = label_ranking_average_precision_score(y_true, y_score)
|
2202 | 2207 | assert result == pytest.approx(2 / 3)
|
|
0 commit comments