8000 Add tests for `alphas` validation in `RidgeClassifierCV` · scikit-learn/scikit-learn@c364401 · GitHub
[go: up one dir, main page]

Skip to content

Commit c364401

Browse files
author
ArturoAmorQ
committed
Add tests for alphas validation in RidgeClassifierCV
1 parent c76ac26 commit c364401

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

sklearn/linear_model/tests/test_ridge.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,34 @@ def test_ridgecv_scalar_alphas():
12621262
ridge.fit(X, y)
12631263

12641264

1265+
def test_ridgeclassifiercv_scalar_alphas():
1266+
X, Y = make_multilabel_classification(n_classes=1, random_state=0)
1267+
Y = Y.reshape(-1, 1)
1268+
y = np.concatenate([Y, Y], axis=1)
1269+
# The method for fitting _BaseRidgeCV depends whether cv=None
1270+
cv = KFold(3)
1271+
1272+
clf = RidgeClassifierCV(alphas=(1, -1, -100))
1273+
with pytest.raises(ValueError, match=r"alphas\[1\] == -1, must be > 0.0"):
1274+
clf.fit(X, y)
1275+
1276+
# Negative floats and cv is not A5DF None
1277+
clf = RidgeClassifierCV(alphas=(-0.1, -1.0, -10.0), cv=cv)
1278+
with pytest.raises(ValueError, match=r"alphas\[0\] == -0.1, must be > 0.0"):
1279+
clf.fit(X, y)
1280+
1281+
# Strings
1282+
clf = RidgeClassifierCV(alphas=(1, 1.0, "1"))
1283+
with pytest.raises(
1284+
TypeError,
1285+
match=(
1286+
r"alphas\[2\] must be an instance of <class 'numbers.Real'>, not <class"
1287+
" 'str'>"
1288+
),
1289+
):
1290+
clf.fit(X, y)
1291+
1292+
12651293
def test_raises_value_error_if_solver_not_supported():
12661294
# Tests whether a ValueError is raised if a non-identified solver
12671295
# is passed to ridge_regression

0 commit comments

Comments
 (0)
0