API fix params validation in SGD inherited models (#20683) · scikit-learn/scikit-learn@d44fd44 · GitHub
[go: up one dir, main page]

Skip to content

Commit d44fd44

Browse files
authored
API fix params validation in SGD inherited models (#20683)
1 parent 4637682 commit d44fd44

File tree

4 files changed

+137
-168
lines changed

4 files changed

+137
-168
lines changed

doc/whats_new/v1.0.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,18 @@ Changelog
527527
coordinate descent solver. Otherwise, an error will be raised.
528528
:pr:`19391` by :user:`Shao Yang Hong <hongshaoyang>`.
529529

530+
- |API| Keyword validation has moved from `__init__` and `set_params` to `fit`
531+
for the following estimators, in order to conform to scikit-learn's conventions:
532+
:class:`linear_model.SGDClassifier`,
533+
:class:`linear_model.SparseSGDClassifier`,
534+
:class:`linear_model.SGDRegressor`,
535+
:class:`linear_model.SparseSGDRegressor`,
536+
:class:`linear_model.SGDOneClassSVM`,
537+
:class:`linear_model.SparseSGDOneClassSVM`,
538+
:class:`linear_model.PassiveAggressiveClassifier`,
539+
:class:`linear_model.PassiveAggressiveRegressor`.
540+
:pr:`20683` by `Guillaume Lemaitre`_.
541+
530542
:mod:`sklearn.manifold`
531543
.......................
532544

sklearn/linear_model/_stochastic_gradient.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,26 +121,6 @@ def __init__(
121121
self.average = average
122122
self.max_iter = max_iter
123123
self.tol = tol
124-
# current tests expect init to do parameter validation
125-
# but we are not allowed to set attributes
126-
self._validate_params()
127-
128-
def set_params(self, **kwargs):
129-
"""Set and validate the parameters of estimator.
130-
131-
Parameters
132-
----------
133-
**kwargs : dict
134-
Estimator parameters.
135-
136-
Returns
137-
-------
138-
self : object
139-
Estimator instance.
140-
"""
141-
super().set_params(**kwargs)
142-
self._validate_params()
143-
return self
144124

145125
@abstractmethod
146126
def fit(self, X, y):

sklearn/linear_model/tests/test_passive_aggressive.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55

6+
from sklearn.base import is_classifier
67
from sklearn.utils._testing import assert_array_almost_equal
78
from sklearn.utils._testing import assert_array_equal
89
from sklearn.utils._testing import assert_almost_equal
@@ -136,11 +137,13 @@ def test_classifier_correctness(loss):
136137
assert_array_almost_equal(clf1.w, clf2.coef_.ravel(), decimal=2)
137138

138139

139-
def test_classifier_undefined_methods():
140+
@pytest.mark.parametrize(
141+
"response_method", ["predict_proba", "predict_log_proba", "transform"]
142+
)
143+
def test_classifier_undefined_methods(response_method):
140144
clf = PassiveAggressiveClassifier(max_iter=100)
141-
for meth in ("predict_proba", "predict_log_proba", "transform"):
142-
with pytest.raises(AttributeError):
143-
getattr(clf, meth)
145+
with pytest.raises(AttributeError):
146+
getattr(clf, response_method)
144147

145148

146149
def test_class_weights():
@@ -279,6 +282,37 @@ def test_regressor_correctness(loss):
279282

280283
def test_regressor_undefined_methods():
281284
reg = PassiveAggressiveRegressor(max_iter=100)
282-
for meth in ("transform",):
283-
with pytest.raises(AttributeError):
284-
getattr(reg, meth)
285+
with pytest.raises(AttributeError):
286+
reg.transform(X)
287+
288+
289+
@pytest.mark.parametrize(
290+
"klass", [PassiveAggressiveClassifier, PassiveAggressiveRegressor]
291+
)
292+
@pytest.mark.parametrize("fit_method", ["fit", "partial_fit"])
293+
@pytest.mark.parametrize(
294+
"params, err_msg",
295+
[
296+
({"loss": "foobar"}, "The loss foobar is not supported"),
297+
({"max_iter": -1}, "max_iter must be > zero"),
298+
({"shuffle": "false"}, "shuffle must be either True or False"),
299+
({"early_stopping": "false"}, "early_stopping must be either True or False"),
300+
(
301+
{"validation_fraction": -0.1},
302+
r"validation_fraction must be in range \(0, 1\)",
303+
),
304+
({"n_iter_no_change": 0}, "n_iter_no_change must be >= 1"),
305+
],
306+
)
307+
def test_passive_aggressive_estimator_params_validation(
308+
klass, fit_method, params, err_msg
309+
):
310+
"""Validate parameters in the different PassiveAggressive estimators."""
311+
sgd_estimator = klass(**params)
312+
313+
with pytest.raises(ValueError, match=err_msg):
314+
if is_classifier(sgd_estimator) and fit_method == "partial_fit":
315+
fit_params = {"classes": np.unique(y)}
316+
else:
317+
fit_params = {}
318+
getattr(sgd_estimator, fit_method)(X, y, **fit_params)

sklearn/linear_model/tests/test_sgd.py

Lines changed: 84 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import pickle
2-
import pytest
32

3+
import joblib
4+
import pytest
45
import numpy as np
5-
from numpy.testing import assert_allclose
66
import scipy.sparse as sp
7-
import joblib
87

8+
from sklearn.utils._testing import assert_allclose
99
from sklearn.utils._testing import assert_array_equal
1010
from sklearn.utils._testing import assert_almost_equal
1111
from sklearn.utils._testing import assert_array_almost_equal
@@ -216,30 +216,55 @@ def asgd(klass, X, y, eta, alpha, weight_init=None, intercept_init=0.0):
216216

217217

218218
@pytest.mark.parametrize(
219-
"klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
220-
)
221-
def test_sgd_bad_alpha(klass):
222-
# Check whether expected ValueError on bad alpha
223-
with pytest.raises(ValueError):
224-
klass(alpha=-0.1)
225-
226-
227-
@pytest.mark.parametrize(
228-
"klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
219+
"klass",
220+
[
221+
SGDClassifier,
222+
SparseSGDClassifier,
223+
SGDRegressor,
224+
SparseSGDRegressor,
225+
SGDOneClassSVM,
226+
SparseSGDOneClassSVM,
227+
],
229228
)
230-
def test_sgd_bad_penalty(klass):
231-
# Check whether expected ValueError on bad penalty
232-
with pytest.raises(ValueError):
233-
klass(penalty="foobar", l1_ratio=0.85)
234-
235-
229+
@pytest.mark.parametrize("fit_method", ["fit", "partial_fit"])
236230
@pytest.mark.parametrize(
237-
"klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
231+
"params, err_msg",
232+
[
233+
({"alpha": -0.1}, "alpha must be >= 0"),
234+
({"penalty": "foobar", "l1_ratio": 0.85}, "Penalty foobar is not supported"),
235+
({"loss": "foobar"}, "The loss foobar is not supported"),
236+
({"l1_ratio": 1.1}, r"l1_ratio must be in \[0, 1\]"),
237+
({"learning_rate": "<unknown>"}, "learning rate <unknown> is not supported"),
238+
({"nu": -0.5}, r"nu must be in \(0, 1]"),
239+
({"nu": 2}, r"nu must be in \(0, 1]"),
240+
({"alpha": 0, "learning_rate": "optimal"}, "alpha must be > 0"),
241+
({"eta0": 0, "learning_rate": "constant"}, "eta0 must be > 0"),
242+
({"max_iter": -1}, "max_iter must be > zero"),
243+
({"shuffle": "false"}, "shuffle must be either True or False"),
244+
({"early_stopping": "false"}, "early_stopping must be either True or False"),
245+
(
246+
{"validation_fraction": -0.1},
247+
r"validation_fraction must be in range \(0, 1\)",
248+
),
249+
({"n_iter_no_change": 0}, "n_iter_no_change must be >= 1"),
250+
],
238251
)
239-
def test_sgd_bad_loss(klass):
240-
# Check whether expected ValueError on bad loss
241-
with pytest.raises(ValueError):
242-
klass(loss="foobar")
252+
def test_sgd_estimator_params_validation(klass, fit_method, params, err_msg):
253+
"""Validate parameters in the different SGD estimators."""
254+
try:
255+
sgd_estimator = klass(**params)
256+
except TypeError as err:
257+
if "__init__() got an unexpected keyword argument" in str(err):
258+
# skip test if the parameter is not supported by the estimator
259+
return
260+
raise err
261+
262+
with pytest.raises(ValueError, match=err_msg):
263+
if is_classifier(sgd_estimator) and fit_method == "partial_fit":
264+
fit_params = {"classes": np.unique(Y)}
265+
else:
266+
fit_params = {}
267+
getattr(sgd_estimator, fit_method)(X, Y, **fit_params)
243268

244269

245270
def _test_warm_start(klass, X, Y, lr):
@@ -408,16 +433,6 @@ def test_late_onset_averaging_reached(klass):
408433
assert_almost_equal(clf1.intercept_, average_intercept, decimal=16)
409434

410435

411-
@pytest.mark.parametrize(
412-
"klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
413-
)
414-
def test_sgd_bad_alpha_for_optimal_learning_rate(klass):
415-
# Check whether expected ValueError on bad alpha, i.e. 0
416-
# since alpha is used to compute the optimal learning rate
417-
with pytest.raises(ValueError):
418-
klass(alpha=0, learning_rate="optimal")
419-
420-
421436
@pytest.mark.parametrize(
422437
"klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
423438
)
@@ -540,115 +555,56 @@ def test_sgd_clf(klass):
540555
assert_array_equal(clf.predict(T), true_result)
541556

542557

543-
@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
544-
def test_sgd_bad_l1_ratio(klass):
545-
# Check whether expected ValueError on bad l1_ratio
546-
with pytest.raises(ValueError):
547-
klass(l1_ratio=1.1)
548-
549-
550-
@pytest.mark.parametrize(
551-
"klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
552-
)
553-
def test_sgd_bad_learning_rate_schedule(klass):
554-
# Check whether expected ValueError on bad learning_rate
555-
with pytest.raises(ValueError):
556-
klass(learning_rate="<unknown>")
557-
558-
559-
@pytest.mark.parametrize(
560-
"klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
561-
)
562-
def test_sgd_bad_eta0(klass):
563-
# Check whether expected ValueError on bad eta0
564-
with pytest.raises(ValueError):
565-
klass(eta0=0, learning_rate="constant")
566-
567-
568-
@pytest.mark.parametrize(
569-
"klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
570-
)
571-
def test_sgd_max_iter_param(klass):
572-
# Test parameter validity check
573-
with pytest.raises(ValueError):
574-
klass(max_iter=-10000)
575-
576-
577-
@pytest.mark.parametrize(
578-
"klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
579-
)
580-
def test_sgd_shuffle_param(klass):
581-
# Test parameter validity check
582-
with pytest.raises(ValueError):
583-
klass(shuffle="false")
584-
585-
586-
@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
587-
def test_sgd_early_stopping_param(klass):
588-
# Test parameter validity check
589-
with pytest.raises(ValueError):
590-
klass(early_stopping="false")
591-
592-
593-
@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
594-
def test_sgd_validation_fraction(klass):
595-
# Test parameter validity check
596-
with pytest.raises(ValueError):
597-
klass(validation_fraction=-0.1)
598-
599-
600-
@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
601-
def test_sgd_n_iter_no_change(klass):
602-
# Test parameter validity check
603-
with pytest.raises(ValueError):
604-
klass(n_iter_no_change=0)
605-
606-
607-
@pytest.mark.parametrize(
608-
"klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
609-
)
610-
def test_argument_coef(klass):
611-
# Checks coef_init not allowed as model argument (only fit)
612-
# Provided coef_ does not match dataset
613-
with pytest.raises(TypeError):
614-
klass(coef_init=np.zeros((3,)))
615-
616-
617558
@pytest.mark.parametrize(
618559
"klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
619560
)
620561
def test_provide_coef(klass):
621-
# Checks coef_init shape for the warm starts
622-
# Provided coef_ does not match dataset.
623-
with pytest.raises(ValueError):
562+
"""Check that the shape of `coef_init` is validated."""
563+
with pytest.raises(ValueError, match="Provided coef_init does not match dataset"):
624564
klass().fit(X, Y, coef_init=np.zeros((3,)))
625565

626566

627567
@pytest.mark.parametrize(
628-
"klass", [SGDClassifier, SparseSGDClassifier, SGDOneClassSVM, SparseSGDOneClassSVM]
568+
"klass, fit_params",
569+
[
570+
(SGDClassifier, {"intercept_init": np.zeros((3,))}),
571+
(SparseSGDClassifier, {"intercept_init": np.zeros((3,))}),
572+
(SGDOneClassSVM, {"offset_init": np.zeros((3,))}),
573+
(SparseSGDOneClassSVM, {"offset_init": np.zeros((3,))}),
574+
],
629575
)
630-
def test_set_intercept(klass):
631-
# Checks intercept_ shape for the warm starts
632-
# Provided intercept_ does not match dataset.
633-
if klass in [SGDClassifier, SparseSGDClassifier]:
634-
with pytest.raises(ValueError):
635-
klass().fit(X, Y, intercept_init=np.zeros((3,)))
636-
elif klass in [SGDOneClassSVM, SparseSGDOneClassSVM]:
637-
with pytest.raises(ValueError):
638-
klass().fit(X, Y, offset_init=np.zeros((3,)))
576+
def test_set_intercept_offset(klass, fit_params):
577+
"""Check that `intercept_init` or `offset_init` is validated."""
578+
sgd_estimator = klass()
579+
with pytest.raises(ValueError, match="does not match dataset"):
580+
sgd_estimator.fit(X, Y, **fit_params)
639581

640582

641-
@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
583+
@pytest.mark.parametrize(
584+
"klass", [SGDClassifier, SparseSGDClassifier, SGDRegressor, SparseSGDRegressor]
585+
)
642586
def test_sgd_early_stopping_with_partial_fit(klass):
643-
# Test parameter validity check
644-
with pytest.raises(ValueError):
587+
"""Check that we raise an error for `early_stopping` used with
588+
`partial_fit`.
589+
"""
590+
err_msg = "early_stopping should be False with partial_fit"
591+
with pytest.raises(ValueError, match=err_msg):
645592
klass(early_stopping=True).partial_fit(X, Y)
646593

647594

648-
@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
649-
def test_set_intercept_binary(klass):
650-
# Checks intercept_ shape for the warm starts in binary case
651-
klass().fit(X5, Y5, intercept_init=0)
595+
@pytest.mark.parametrize(
596+
"klass, fit_params",
597+
[
598+
(SGDClassifier, {"intercept_init": 0}),
599+
(SparseSGDClassifier, {"intercept_init": 0}),
600+
(SGDOneClassSVM, {"offset_init": 0}),
601+
(SparseSGDOneClassSVM, {"offset_init": 0}),
602+
],
603+
)
604+
def test_set_intercept_offset_binary(klass, fit_params):
605+
"""Check that we can pass a scalar with binary classification to
606+
`intercept_init` or `offset_init`."""
607+
klass().fit(X5, Y5, **fit_params)
652608

653609

654610
@pytest.mark.parametrize("klass", [SGDClassifier, SparseSGDClassifier])
@@ -1537,19 +1493,6 @@ def asgd_oneclass(klass, X, eta, nu, coef_init=None, offset_init=0.0):
15371493
return average_coef, 1 - average_intercept
15381494

15391495

1540-
@pytest.mark.parametrize("klass", [SGDOneClassSVM, SparseSGDOneClassSVM])
1541-
@pytest.mark.parametrize("nu", [-0.5, 2])
1542-
def test_bad_nu_values(klass, nu):
1543-
msg = r"nu must be in \(0, 1]"
1544-
with pytest.raises(ValueError, match=msg):
1545-
klass(nu=nu)
1546-
1547-
clf = klass(nu=0.05)
1548-
clf2 = clone(clf)
1549-
with pytest.raises(ValueError, match=msg):
1550-
clf2.set_params(nu=nu)
1551-
1552-
15531496
@pytest.mark.parametrize("klass", [SGDOneClassSVM, SparseSGDOneClassSVM])
15541497
def _test_warm_start_oneclass(klass, X, lr):
15551498
# Test that explicit warm restart...

0 commit comments

Comments
 (0)
0