8000 TST Extend tests for `scipy.sparse.*array` in `sklearn/tests/test_mul… · scikit-learn/scikit-learn@749136e · GitHub
[go: up one dir, main page]

Skip to content

Commit 749136e

Browse files
TialoOmarManzoorglemaitre
authored
TST Extend tests for scipy.sparse.*array in sklearn/tests/test_multiclass.py (#27223)
Co-authored-by: Omar Salman <omar.salman@arbisoft.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 0ee7553 commit 749136e

File tree

1 file changed

+45
-40
lines changed

1 file changed

+45
-40
lines changed

sklearn/tests/test_multiclass.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@
3636
)
3737
from sklearn.utils._mocking import CheckingClassifier
3838
from sklearn.utils._testing import assert_almost_equal, assert_array_equal
39+
from sklearn.utils.fixes import (
40+
COO_CONTAINERS,
41+
CSC_CONTAINERS,
42+
CSR_CONTAINERS,
43+
DOK_CONTAINERS,
44+
LIL_CONTAINERS,
45+
)
3946
from sklearn.utils.multiclass import check_classification_targets, type_of_target
4047

4148
msg = "The default value for `force_alpha` will change"
@@ -160,52 +167,49 @@ def test_ovr_ovo_regressor():
160167
assert np.mean(pred == iris.target) > 0.9
161168

162169

163-
def test_ovr_fit_predict_sparse():
164-
for sparse in [
165-
sp.csr_matrix,
166-
sp.csc_matrix,
167-
sp.coo_matrix,
168-
sp.dok_matrix,
169-
sp.lil_matrix,
170-
]:
171-
base_clf = MultinomialNB(alpha=1)
170+
@pytest.mark.parametrize(
171+
"sparse_container",
172+
CSR_CONTAINERS + CSC_CONTAINERS + COO_CONTAINERS + DOK_CONTAINERS + LIL_CONTAINERS,
173+
)
174+
def test_ovr_fit_predict_sparse(sparse_container):
175+
base_clf = MultinomialNB(alpha=1)
172176

173-
X, Y = datasets.make_multilabel_classification(
174-
n_samples=100,
175-
n_features=20,
176-
n_classes=5,
177-
n_labels=3,
178-
length=50,
179-
allow_unlabeled=True,
180-
random_state=0,
181-
)
177+
X, Y = datasets.make_multilabel_classification(
178+
n_samples=100,
179+
n_features=20,
180+
n_classes=5,
181+
n_labels=3,
182+
length=50,
183+
allow_unlabeled=True,
184+
random_state=0,
185+
)
182186

183-
X_train, Y_train = X[:80], Y[:80]
184-
X_test = X[80:]
187+
X_train, Y_train = X[:80], Y[:80]
188+
X_test = X[80:]
185189

186-
clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)
187-
Y_pred = clf.predict(X_test)
190+
clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)
191+
Y_pred = clf.predict(X_test)
188192

189-
clf_sprs = OneVsRestClassifier(base_clf).fit(X_train, sparse(Y_train))
190-
Y_pred_sprs = clf_sprs.predict(X_test)
193+
clf_sprs = OneVsRestClassifier(base_clf).fit(X_train, sparse_container(Y_train))
194+
Y_pred_sprs = clf_sprs.predict(X_test)
191195

192-
assert clf.multilabel_
193-
assert sp.issparse(Y_pred_sprs)
194-
assert_array_equal(Y_pred_sprs.toarray(), Y_pred)
196+
assert clf.multilabel_
197+
assert sp.issparse(Y_pred_sprs)
198+
assert_array_equal(Y_pred_sprs.toarray(), Y_pred)
195199

196-
# Test predict_proba
197-
Y_proba = clf_sprs.predict_proba(X_test)
200+
# Test predict_proba
201+
Y_proba = clf_sprs.predict_proba(X_test)
198202

199-
# predict assigns a label if the probability that the
200-
# sample has the label is greater than 0.5.
201-
pred = Y_proba > 0.5
202-
assert_array_equal(pred, Y_pred_sprs.toarray())
203+
# predict assigns a label if the probability that the
204+
# sample has the label is greater than 0.5.
205+
pred = Y_proba > 0.5
206+
assert_array_equal(pred, Y_pred_sprs.toarray())
203207

204-
# Test decision_function
205-
clf = svm.SVC()
206-
clf_sprs = OneVsRestClassifier(clf).fit(X_train, sparse(Y_train))
207-
dec_pred = (clf_sprs.decision_function(X_test) > 0).astype(int)
208-
assert_array_equal(dec_pred, clf_sprs.predict(X_test).toarray())
208+
# Test decision_function
209+
clf = svm.SVC()
210+
clf_sprs = OneVsRestClassifier(clf).fit(X_train, sparse_container(Y_train))
211+
dec_pred = (clf_sprs.decision_function(X_test) > 0).astype(int)
212+
assert_array_equal(dec_pred, clf_sprs.predict(X_test).toarray())
209213

210214

211215
def test_ovr_always_present():
@@ -723,11 +727,12 @@ def test_ecoc_float_y():
723727
ovo.fit(X, y)
724728

725729

726-
def test_ecoc_delegate_sparse_base_estimator():
730+
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
731+
def test_ecoc_delegate_sparse_base_estimator(csc_container):
727732
# Non-regression test for
728733
# https://github.com/scikit-learn/scikit-learn/issues/17218
729734
X, y = iris.data, iris.target
730-
X_sp = sp.csc_matrix(X)
735+
X_sp = csc_container(X)
731736

732737
# create an estimator that does not support sparse input
733738
base_estimator = CheckingClassifier(

0 commit comments

Comments
 (0)
0