|
36 | 36 | )
|
37 | 37 | from sklearn.utils._mocking import CheckingClassifier
|
38 | 38 | from sklearn.utils._testing import assert_almost_equal, assert_array_equal
|
| 39 | +from sklearn.utils.fixes import ( |
| 40 | + COO_CONTAINERS, |
| 41 | + CSC_CONTAINERS, |
| 42 | + CSR_CONTAINERS, |
| 43 | + DOK_CONTAINERS, |
| 44 | + LIL_CONTAINERS, |
| 45 | +) |
39 | 46 | from sklearn.utils.multiclass import check_classification_targets, type_of_target
|
40 | 47 |
|
41 | 48 | msg = "The default value for `force_alpha` will change"
|
@@ -160,52 +167,49 @@ def test_ovr_ovo_regressor():
|
160 | 167 | assert np.mean(pred == iris.target) > 0.9
|
161 | 168 |
|
162 | 169 |
|
163 |
| -def test_ovr_fit_predict_sparse(): |
164 |
| - for sparse in [ |
165 |
| - sp.csr_matrix, |
166 |
| - sp.csc_matrix, |
167 |
| - sp.coo_matrix, |
168 |
| - sp.dok_matrix, |
169 |
| - sp.lil_matrix, |
170 |
| - ]: |
171 |
| - base_clf = MultinomialNB(alpha=1) |
| 170 | +@pytest.mark.parametrize( |
| 171 | + "sparse_container", |
| 172 | + CSR_CONTAINERS + CSC_CONTAINERS + COO_CONTAINERS + DOK_CONTAINERS + LIL_CONTAINERS, |
| 173 | +) |
| 174 | +def test_ovr_fit_predict_sparse(sparse_container): |
| 175 | + base_clf = MultinomialNB(alpha=1) |
172 | 176 |
|
173 |
| - X, Y = datasets.make_multilabel_classification( |
174 |
| - n_samples=100, |
175 |
| - n_features=20, |
176 |
| - n_classes=5, |
177 |
| - n_labels=3, |
178 |
| - length=50, |
179 |
| - allow_unlabeled=True, |
180 |
| - random_state=0, |
181 |
| - ) |
| 177 | + X, Y = datasets.make_multilabel_classification( |
| 178 | + n_samples=100, |
| 179 | + n_features=20, |
| 180 | + n_classes=5, |
| 181 | + n_labels=3, |
| 182 | + length=50, |
| 183 | + allow_unlabeled=True, |
| 184 | + random_state=0, |
| 185 | + ) |
182 | 186 |
|
183 |
| - X_train, Y_train = X[:80], Y[:80] |
184 |
| - X_test = X[80:] |
| 187 | + X_train, Y_train = X[:80], Y[:80] |
| 188 | + X_test = X[80:] |
185 | 189 |
|
186 |
| - clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train) |
187 |
| - Y_pred = clf.predict(X_test) |
| 190 | + clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train) |
| 191 | + Y_pred = clf.predict(X_test) |
188 | 192 |
|
189 |
| - clf_sprs = OneVsRestClassifier(base_clf).fit(X_train, sparse(Y_train)) |
190 |
| - Y_pred_sprs = clf_sprs.predict(X_test) |
| 193 | + clf_sprs = OneVsRestClassifier(base_clf).fit(X_train, sparse_container(Y_train)) |
| 194 | + Y_pred_sprs = clf_sprs.predict(X_test) |
191 | 195 |
|
192 |
| - assert clf.multilabel_ |
193 |
| - assert sp.issparse(Y_pred_sprs) |
194 |
| - assert_array_equal(Y_pred_sprs.toarray(), Y_pred) |
| 196 | + assert clf.multilabel_ |
| 197 | + assert sp.issparse(Y_pred_sprs) |
| 198 | + assert_array_equal(Y_pred_sprs.toarray(), Y_pred) |
195 | 199 |
|
196 |
| - # Test predict_proba |
197 |
| - Y_proba = clf_sprs.predict_proba(X_test) |
| 200 | + # Test predict_proba |
| 201 | + Y_proba = clf_sprs.predict_proba(X_test) |
198 | 202 |
|
199 |
| - # predict assigns a label if the probability that the |
200 |
| - # sample has the label is greater than 0.5. |
201 |
| - pred = Y_proba > 0.5 |
202 |
| - assert_array_equal(pred, Y_pred_sprs.toarray()) |
| 203 | + # predict assigns a label if the probability that the |
| 204 | + # sample has the label is greater than 0.5. |
| 205 | + pred = Y_proba > 0.5 |
| 206 | + assert_array_equal(pred, Y_pred_sprs.toarray()) |
203 | 207 |
|
204 |
| - # Test decision_function |
205 |
| - clf = svm.SVC() |
206 |
| - clf_sprs = OneVsRestClassifier(clf).fit(X_train, sparse(Y_train)) |
207 |
| - dec_pred = (clf_sprs.decision_function(X_test) > 0).astype(int) |
208 |
| - assert_array_equal(dec_pred, clf_sprs.predict(X_test).toarray()) |
| 208 | + # Test decision_function |
| 209 | + clf = svm.SVC() |
| 210 | + clf_sprs = OneVsRestClassifier(clf).fit(X_train, sparse_container(Y_train)) |
| 211 | + dec_pred = (clf_sprs.decision_function(X_test) > 0).astype(int) |
| 212 | + assert_array_equal(dec_pred, clf_sprs.predict(X_test).toarray()) |
209 | 213 |
|
210 | 214 |
|
211 | 215 | def test_ovr_always_present():
|
@@ -723,11 +727,12 @@ def test_ecoc_float_y():
|
723 | 727 | ovo.fit(X, y)
|
724 | 728 |
|
725 | 729 |
|
726 |
| -def test_ecoc_delegate_sparse_base_estimator(): |
| 730 | +@pytest.mark.parametrize("csc_container", CSC_CONTAINERS) |
| 731 | +def test_ecoc_delegate_sparse_base_estimator(csc_container): |
727 | 732 | # Non-regression test for
|
728 | 733 | # https://github.com/scikit-learn/scikit-learn/issues/17218
|
729 | 734 | X, y = iris.data, iris.target
|
730 |
| - X_sp = sp.csc_matrix(X) |
| 735 | + X_sp = csc_container(X) |
731 | 736 |
|
732 | 737 | # create an estimator that does not support sparse input
|
733 | 738 | base_estimator = CheckingClassifier(
|
|
0 commit comments