|
1 | 1 | import pickle
|
2 |
| -import pytest |
3 | 2 |
|
| 3 | +import joblib |
| 4 | +import pytest |
4 | 5 | import numpy as np
|
5 |
| -from numpy.testing import assert_allclose |
6 | 6 | import scipy.sparse as sp
|
7 |
| -import joblib |
8 | 7 |
|
| 8 | +from sklearn.utils._testing import assert_allclose |
9 | 9 | from sklearn.utils._testing import assert_array_equal
|
10 | 10 | from sklearn.utils._testing import assert_almost_equal
|
11 | 11 | from sklearn.utils._testing import assert_array_almost_equal
|
@@ -216,30 +216,55 @@ def asgd(klass, X, y, eta, alpha, weight_init=None, intercept_init=0.0):
|
216 | 216 |
|
217 | 217 |
|
@pytest.mark.parametrize(
    "klass",
    [
        SGDClassifier,
        SparseSGDClassifier,
        SGDRegressor,
        SparseSGDRegressor,
        SGDOneClassSVM,
        SparseSGDOneClassSVM,
    ],
)
@pytest.mark.parametrize("fit_method", ["fit", "partial_fit"])
@pytest.mark.parametrize(
    "params, err_msg",
    [
        ({"alpha": -0.1}, "alpha must be >= 0"),
        ({"penalty": "foobar", "l1_ratio": 0.85}, "Penalty foobar is not supported"),
        ({"loss": "foobar"}, "The loss foobar is not supported"),
        ({"l1_ratio": 1.1}, r"l1_ratio must be in \[0, 1\]"),
        ({"learning_rate": "<unknown>"}, "learning rate <unknown> is not supported"),
        ({"nu": -0.5}, r"nu must be in \(0, 1]"),
        ({"nu": 2}, r"nu must be in \(0, 1]"),
        ({"alpha": 0, "learning_rate": "optimal"}, "alpha must be > 0"),
        ({"eta0": 0, "learning_rate": "constant"}, "eta0 must be > 0"),
        ({"max_iter": -1}, "max_iter must be > zero"),
        ({"shuffle": "false"}, "shuffle must be either True or False"),
        ({"early_stopping": "false"}, "early_stopping must be either True or False"),
        (
            {"validation_fraction": -0.1},
            r"validation_fraction must be in range \(0, 1\)",
        ),
        ({"n_iter_no_change": 0}, "n_iter_no_change must be >= 1"),
    ],
)
def test_sgd_estimator_params_validation(klass, fit_method, params, err_msg):
    """Validate parameters in the different SGD estimators.

    Each invalid hyper-parameter combination must raise a ``ValueError``
    matching ``err_msg``, from both ``fit`` and ``partial_fit``.
    """
    try:
        sgd_estimator = klass(**params)
    except TypeError as err:
        if "__init__() got an unexpected keyword argument" in str(err):
            # The parameter is not supported by this estimator (e.g. `nu`
            # only exists on the one-class SVM variants). Report the case
            # as skipped instead of silently passing.
            pytest.skip(f"{klass.__name__} does not support parameters {params}")
        raise

    # `partial_fit` on a classifier requires the full set of classes up
    # front. Build the fit kwargs *outside* the `pytest.raises` block so
    # that an unexpected error here cannot masquerade as the expected
    # validation ValueError.
    if is_classifier(sgd_estimator) and fit_method == "partial_fit":
        fit_params = {"classes": np.unique(Y)}
    else:
        fit_params = {}

    with pytest.raises(ValueError, match=err_msg):
        getattr(sgd_estimator, fit_method)(X, Y, **fit_params)
243 | 268 |
|
244 | 269 |
|
245 | 270 | def _test_warm_start(klass, X, Y, lr):
|
@@ -408,16 +433,6 @@ def test_late_onset_averaging_reached(klass):
|
408 | 433 | assert_almost_equal(clf1.intercept_, average_intercept, decimal=16)
|
409 | 434 |
|
410 | 435 |
|
411 |
| -@pytest.mark.parametrize( |
412 |
| - "klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor] |
413 |
| -) |
414 |
| -def test_sgd_bad_alpha_for_optimal_learning_rate(klass): |
415 |
| - # Check whether expected ValueError on bad alpha, i.e. 0 |
416 |
| - # since alpha is used to compute the optimal learning rate |
417 |
| - with pytest.raises(ValueError): |
418 |
| - klass(alpha=0, learning_rate="optimal") |
419 |
| - |
420 |
| - |
421 | 436 | @pytest.mark.parametrize(
|
422 | 437 | "klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
|
423 | 438 | )
|
@@ -540,115 +555,56 @@ def test_sgd_clf(klass):
|
540 | 555 | assert_array_equal(clf.predict(T), true_result)
|
541 | 556 |
|
542 | 557 |
|
543 |
| -@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier]) |
544 |
| -def test_sgd_bad_l1_ratio(klass): |
545 |
| - # Check whether expected ValueError on bad l1_ratio |
546 |
| - with pytest.raises(ValueError): |
547 |
| - klass(l1_ratio=1.1) |
548 |
| - |
549 |
| - |
550 |
| -@pytest.mark.parametrize( |
551 |
| - "klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM] |
552 |
| -) |
553 |
| -def test_sgd_bad_learning_rate_schedule(klass): |
554 |
| - # Check whether expected ValueError on bad learning_rate |
555 |
| - with pytest.raises(ValueError): |
556 |
| - klass(learning_rate="<unknown>") |
557 |
| - |
558 |
| - |
559 |
| -@pytest.mark.parametrize( |
560 |
| - "klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM] |
561 |
| -) |
562 |
| -def test_sgd_bad_eta0(klass): |
563 |
| - # Check whether expected ValueError on bad eta0 |
564 |
| - with pytest.raises(ValueError): |
565 |
| - klass(eta0=0, learning_rate="constant") |
566 |
| - |
567 |
| - |
568 |
| -@pytest.mark.parametrize( |
569 |
| - "klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM] |
570 |
| -) |
571 |
| -def test_sgd_max_iter_param(klass): |
572 |
| - # Test parameter validity check |
573 |
| - with pytest.raises(ValueError): |
574 |
| - klass(max_iter=-10000) |
575 |
| - |
576 |
| - |
577 |
| -@pytest.mark.parametrize( |
578 |
| - "klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM] |
579 |
| -) |
580 |
| -def test_sgd_shuffle_param(klass): |
581 |
| - # Test parameter validity check |
582 |
| - with pytest.raises(ValueError): |
583 |
| - klass(shuffle="false") |
584 |
| - |
585 |
| - |
586 |
| -@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier]) |
587 |
| -def test_sgd_early_stopping_param(klass): |
588 |
| - # Test parameter validity check |
589 |
| - with pytest.raises(ValueError): |
590 |
| - klass(early_stopping="false") |
591 |
| - |
592 |
| - |
593 |
| -@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier]) |
594 |
| -def test_sgd_validation_fraction(klass): |
595 |
| - # Test parameter validity check |
596 |
| - with pytest.raises(ValueError): |
597 |
| - klass(validation_fraction=-0.1) |
598 |
| - |
599 |
| - |
600 |
| -@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier]) |
601 |
| -def test_sgd_n_iter_no_change(klass): |
602 |
| - # Test parameter validity check |
603 |
| - with pytest.raises(ValueError): |
604 |
| - klass(n_iter_no_change=0) |
605 |
| - |
606 |
| - |
607 |
| -@pytest.mark.parametrize( |
608 |
| - "klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM] |
609 |
| -) |
610 |
| -def test_argument_coef(klass): |
611 |
| - # Checks coef_init not allowed as model argument (only fit) |
612 |
| - # Provided coef_ does not match dataset |
613 |
| - with pytest.raises(TypeError): |
614 |
| - klass(coef_init=np.zeros((3,))) |
615 |
| - |
616 |
| - |
@pytest.mark.parametrize(
    "klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
)
def test_provide_coef(klass):
    """Check that the shape of `coef_init` is validated."""
    # A 1-d array of the wrong length cannot match the dataset's features.
    bad_coef_init = np.zeros((3,))
    with pytest.raises(ValueError, match="Provided coef_init does not match dataset"):
        klass().fit(X, Y, coef_init=bad_coef_init)
|
625 | 565 |
|
626 | 566 |
|
@pytest.mark.parametrize(
    "klass, fit_params",
    [
        (SGDClassifier, {"intercept_init": np.zeros((3,))}),
        (SparseSGDClassifier, {"intercept_init": np.zeros((3,))}),
        (SGDOneClassSVM, {"offset_init": np.zeros((3,))}),
        (SparseSGDOneClassSVM, {"offset_init": np.zeros((3,))}),
    ],
)
def test_set_intercept_offset(klass, fit_params):
    """Check that `intercept_init` or `offset_init` is validated."""
    # A warm-start value whose shape disagrees with the data must be
    # rejected at fit time.
    estimator = klass()
    with pytest.raises(ValueError, match="does not match dataset"):
        estimator.fit(X, Y, **fit_params)
639 | 581 |
|
640 | 582 |
|
@pytest.mark.parametrize(
    "klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
)
def test_sgd_early_stopping_with_partial_fit(klass):
    """Check that we raise an error for `early_stopping` used with
    `partial_fit`.
    """
    # Early stopping needs a held-out validation split, which is
    # incompatible with incremental fitting.
    estimator = klass(early_stopping=True)
    err_msg = "early_stopping should be False with partial_fit"
    with pytest.raises(ValueError, match=err_msg):
        estimator.partial_fit(X, Y)
|
646 | 593 |
|
647 | 594 |
|
@pytest.mark.parametrize(
    "klass, fit_params",
    [
        (SGDClassifier, {"intercept_init": 0}),
        (SparseSGDClassifier, {"intercept_init": 0}),
        (SGDOneClassSVM, {"offset_init": 0}),
        (SparseSGDOneClassSVM, {"offset_init": 0}),
    ],
)
def test_set_intercept_offset_binary(klass, fit_params):
    """Check that a scalar can be passed as `intercept_init` or
    `offset_init` with binary classification data."""
    # Must fit without raising; only the absence of an error is asserted.
    klass().fit(X5, Y5, **fit_params)
652 | 608 |
|
653 | 609 |
|
654 | 610 | @pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
|
@@ -1537,19 +1493,6 @@ def asgd_oneclass(klass, X, eta, nu, coef_init=None, offset_init=0.0):
|
1537 | 1493 | return average_coef, 1 - average_intercept
|
1538 | 1494 |
|
1539 | 1495 |
|
1540 |
| -@pytest.mark.parametrize("klass", [SGDOneClassSVM, SparseSGDOneClassSVM]) |
1541 |
| -@pytest.mark.parametrize("nu", [-0.5, 2]) |
1542 |
| -def test_bad_nu_values(klass, nu): |
1543 |
| - msg = r"nu must be in \(0, 1]" |
1544 |
| - with pytest.raises(ValueError, match=msg): |
1545 |
| - klass(nu=nu) |
1546 |
| - |
1547 |
| - clf = klass(nu=0.05) |
1548 |
| - clf2 = clone(clf) |
1549 |
| - with pytest.raises(ValueError, match=msg): |
1550 |
| - clf2.set_params(nu=nu) |
1551 |
| - |
1552 |
| - |
1553 | 1496 | @pytest.mark.parametrize("klass", [SGDOneClassSVM, SparseSGDOneClassSVM])
|
1554 | 1497 | def _test_warm_start_oneclass(klass, X, lr):
|
1555 | 1498 | # Test that explicit warm restart...
|
|
0 commit comments