8000 TST Extend tests for `scipy.sparse.*array` in `sklearn/linear_model/t… · REDVM/scikit-learn@2d93457 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2d93457

Browse files
TialoREDVM
authored andcommitted
TST Extend tests for scipy.sparse.*array in sklearn/linear_model/tests/test_sag.py (scikit-learn#27206)
1 parent 83881b0 commit 2d93457

File tree

1 file changed

+29
-20
lines changed

1 file changed

+29
-20
lines changed

sklearn/linear_model/tests/test_sag.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import numpy as np
1010
import pytest
11-
import scipy.sparse as sp
1211
from scipy.special import logsumexp
1312

1413
from sklearn._loss.loss import HalfMultinomialLoss
@@ -27,6 +26,7 @@
2726
assert_array_almost_equal,
2827
)
2928
from sklearn.utils.extmath import row_norms
29+
from sklearn.utils.fixes import CSR_CONTAINERS
3030

3131
iris = load_iris()
3232

@@ -356,7 +356,8 @@ def test_regressor_matching():
356356

357357

358358
@pytest.mark.filterwarnings("ignore:The max_iter was reached")
359-
def test_sag_pobj_matches_logistic_regression():
359+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
360+
def test_sag_pobj_matches_logistic_regression(csr_container):
360361
"""tests if the sag pobj matches log reg"""
361362
n_samples = 100
362363
alpha = 1.0
@@ -383,7 +384,7 @@ def test_sag_pobj_matches_logistic_regression():
383384
)
384385< 10000 /code>

385386
clf1.fit(X, y)
386-
clf2.fit(sp.csr_matrix(X), y)
387+
clf2.fit(csr_container(X), y)
387388
clf3.fit(X, y)
388389

389390
pobj1 = get_pobj(clf1.coef_, alpha, X, y, log_loss)
@@ -396,7 +397,8 @@ def test_sag_pobj_matches_logistic_regression():
396397

397398

398399
@pytest.mark.filterwarnings("ignore:The max_iter was reached")
399-
def test_sag_pobj_matches_ridge_regression():
400+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
401+
def test_sag_pobj_matches_ridge_regression(csr_container):
400402
"""tests if the sag pobj matches ridge reg"""
401403
n_samples = 100
402404
n_features = 10
@@ -427,7 +429,7 @@ def test_sag_pobj_matches_ridge_regression():
427429
)
428430

429431
clf1.fit(X, y)
430-
clf2.fit(sp.csr_matrix(X), y)
432+
clf2.fit(csr_container(X), y)
431433
clf3.fit(X, y)
432434

433435
pobj1 = get_pobj(clf1.coef_, alpha, X, y, squared_loss)
@@ -440,7 +442,8 @@ def test_sag_pobj_matches_ridge_regression():
440442

441443

442444
@pytest.mark.filterwarnings("ignore:The max_iter was reached")
443-
def test_sag_regressor_computed_correctly():
445+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
446+
def test_sag_regressor_computed_correctly(csr_container):
444447
"""tests if the sag regressor is computed correctly"""
445448
alpha = 0.1
446449
n_features = 10
@@ -465,7 +468,7 @@ def test_sag_regressor_computed_correctly():
465468
clf2 = clone(clf1)
466469

467470
clf1.fit(X, y)
468-
clf2.fit(sp.csr_matrix(X), y)
471+
clf2.fit(csr_container(X), y)
469472

470473
spweights1, spintercept1 = sag_sparse(
471474
X,
@@ -551,7 +554,8 @@ def test_get_auto_step_size():
551554

552555

553556
@pytest.mark.parametrize("seed", range(3)) # locally tested with 1000 seeds
554-
def test_sag_regressor(seed):
557+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
558+
def test_sag_regressor(seed, csr_container):
555559
"""tests if the sag regressor performs well"""
556560
xmin, xmax = -5, 5
557561
n_samples = 300
@@ -573,7 +577,7 @@ def test_sag_regressor(seed):
573577
)
574578
clf2 = clone(clf1)
575579
clf1.fit(X, y)
576-
clf2.fit(sp.csr_matrix(X), y)
580+
clf2.fit(csr_container(X), y)
577581
score1 = clf1.score(X, y)
578582
score2 = clf2.score(X, y)
579583
assert score1 > 0.98
@@ -585,15 +589,16 @@ def test_sag_regressor(seed):
585589
clf1 = Ridge(tol=tol, solver="sag", max_iter=max_iter, alpha=alpha * n_samples)
586590
clf2 = clone(clf1)
587591
clf1.fit(X, y)
588-
clf2.fit(sp.csr_matrix(X), y)
592+
clf2.fit(csr_container(X), y)
589593
score1 = clf1.score(X, y)
590594
score2 = clf2.score(X, y)
591595
assert score1 > 0.45
592596
assert score2 > 0.45
593597

594598

595599
@pytest.mark.filterwarnings("ignore:The max_iter was reached")
596-
def test_sag_classifier_computed_correctly():
600+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
601+
def test_sag_classifier_computed_correctly(csr_container):
597602
"""tests if the binary classifier is computed correctly"""
598603
alpha = 0.1
599604
n_samples = 50
@@ -619,7 +624,7 @@ def test_sag_classifier_computed_correctly():
619624
clf2 = clone(clf1)
620625

621626
clf1.fit(X, y)
622-
clf2.fit(sp.csr_matrix(X), y)
627+
clf2.fit(csr_container(X), y)
623628

624629
spweights, spintercept = sag_sparse(
625630
X,
@@ -649,7 +654,8 @@ def test_sag_classifier_computed_correctly():
649654

650655

651656
@pytest.mark.filterwarnings("ignore:The max_iter was reached")
652-
def test_sag_multiclass_computed_correctly():
657+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
658+
def test_sag_multiclass_computed_correctly(csr_container):
653659
"""tests if the multiclass classifier is computed correctly"""
654660
alpha = 0.1
655661
n_samples = 20
@@ -672,7 +678,7 @@ def test_sag_multiclass_computed_correctly():
672678
clf2 = clone(clf1)
673679

674680
clf1.fit(X, y)
675-
clf2.fit(sp.csr_matrix(X), y)
681+
clf2.fit(csr_container(X), y)
676682

677683
coef1 = []
678684
intercept1 = []
@@ -720,7 +726,8 @@ def test_sag_multiclass_computed_correctly():
720726
assert_almost_equal(clf2.intercept_[i], intercept2[i], decimal=1)
721727

722728

723-
def test_classifier_results():
729+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
730+
def test_classifier_results(csr_container):
724731
"""tests if classifier results match target"""
725732
alpha = 0.1
726733
n_features = 20
@@ -742,15 +749,16 @@ def test_classifier_results():
742749
clf2 = clone(clf1)
743750

744751
clf1.fit(X, y)
745-
clf2.fit(sp.csr_matrix(X), y)
752+
clf2.fit(csr_container(X), y)
746753
pred1 = clf1.predict(X)
747754
pred2 = clf2.predict(X)
748755
assert_almost_equal(pred1, y, decimal=12)
749756
assert_almost_equal(pred2, y, decimal=12)
750757

751758

752759
@pytest.mark.filterwarnings("ignore:The max_iter was reached")
753-
def test_binary_classifier_class_weight():
760+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
761+
def test_binary_classifier_class_weight(csr_container):
754762
"""tests binary classifier with classweights for each class"""
755763
alpha = 0.1
756764
n_samples = 50
@@ -778,7 +786,7 @@ def test_binary_classifier_class_weight():
778786
clf2 = clone(clf1)
779787

780788
clf1.fit(X, y)
781-
clf2.fit(sp.csr_matrix(X), y)
789+
clf2.fit(csr_container(X), y)
782790

783791
le = LabelEncoder()
784792
class_weight_ = compute_class_weight(class_weight, classes=np.unique(y), y=y)
@@ -813,7 +821,8 @@ def test_binary_classifier_class_weight():
813821

814822

815823
@pytest.mark.filterwarnings("ignore:The max_iter was reached")
816-
def test_multiclass_classifier_class_weight():
824+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
825+
def test_multiclass_classifier_class_weight(csr_container):
817826
"""tests multiclass with classweights for each class"""
818827
alpha = 0.1
819828
n_samples = 20
@@ -837,7 +846,7 @@ def test_multiclass_classifier_class_weight():
837846
)
838847
clf2 = clone(clf1)
839848
clf1.fit(X, y)
840-
clf2.fit(sp.csr_matrix(X), y)
849+
clf2.fit(csr_container(X), y)
841850

842851
le = LabelEncoder()
843852
class_weight_ = compute_class_weight(class_weight, classes=np.unique(y), y=y)

0 commit comments

Comments
 (0)
0