diff --git a/sklearn/svm/_bounds.py b/sklearn/svm/_bounds.py index f89b6a77fc616..83cb72d30892c 100644 --- a/sklearn/svm/_bounds.py +++ b/sklearn/svm/_bounds.py @@ -2,13 +2,25 @@ # Author: Paolo Losi # License: BSD 3 clause +from numbers import Real + import numpy as np from ..preprocessing import LabelBinarizer from ..utils.validation import check_consistent_length, check_array from ..utils.extmath import safe_sparse_dot +from ..utils._param_validation import StrOptions, Interval, validate_params +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "y": ["array-like"], + "loss": [StrOptions({"squared_hinge", "log"})], + "fit_intercept": ["boolean"], + "intercept_scaling": [Interval(Real, 0, None, closed="neither")], + } +) def l1_min_c(X, y, *, loss="squared_hinge", fit_intercept=True, intercept_scaling=1.0): """Return the lowest bound for C. @@ -49,8 +61,6 @@ def l1_min_c(X, y, *, loss="squared_hinge", fit_intercept=True, intercept_scalin l1_min_c : float Minimum value for C. """ - if loss not in ("squared_hinge", "log"): - raise ValueError('loss type not in ("squared_hinge", "log")') X = check_array(X, accept_sparse="csc") check_consistent_length(X, y) diff --git a/sklearn/svm/tests/test_bounds.py b/sklearn/svm/tests/test_bounds.py index 5ca0f12b5d7f0..23d6be2f44e98 100644 --- a/sklearn/svm/tests/test_bounds.py +++ b/sklearn/svm/tests/test_bounds.py @@ -35,13 +35,6 @@ def test_l1_min_c(loss, X_label, Y_label, intercept_label): check_l1_min_c(X, Y, loss, **intercept_params) -def test_l1_min_c_l2_loss(): - # loss='l2' should raise ValueError - msg = "loss type not in" - with pytest.raises(ValueError, match=msg): - l1_min_c(dense_X, Y1, loss="l2") - - def check_l1_min_c(X, y, loss, fit_intercept=True, intercept_scaling=1.0): min_c = l1_min_c( X, @@ -76,11 +69,6 @@ def test_ill_posed_min_c(): l1_min_c(X, y) -def test_unsupported_loss(): - with pytest.raises(ValueError): - l1_min_c(dense_X, Y1, loss="l1") - - _MAX_UNSIGNED_INT = 4294967295 diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d4e645c052dab..6ea4f5ce73f67 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,6 +10,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", + "sklearn.svm.l1_min_c", ]