8000 TST Make test_ridge_regression_dtype_stability less random (#13816) · scikit-learn/scikit-learn@ff2f923 · GitHub
[go: up one dir, main page]

Skip to content

Commit ff2f923

Browse files
ogriseljnothman
authored andcommitted
TST Make test_ridge_regression_dtype_stability less random (#13816)
1 parent 6525a39 commit ff2f923

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

sklearn/linear_model/tests/test_ridge.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,16 +1124,18 @@ def test_dtype_match_cholesky():
11241124

11251125
@pytest.mark.parametrize(
11261126
'solver', ['svd', 'cholesky', 'lsqr', 'sparse_cg', 'sag', 'saga'])
1127-
def test_ridge_regression_dtype_stability(solver):
1128-
random_state = np.random.RandomState(0)
1127+
@pytest.mark.parametrize('seed', range(1))
1128+
def test_ridge_regression_dtype_stability(solver, seed):
1129+
random_state = np.random.RandomState(seed)
11291130
n_samples, n_features = 6, 5
11301131
X = random_state.randn(n_samples, n_features)
11311132
coef = random_state.randn(n_features)
1132-
y = np.dot(X, coef) + 0.01 * rng.randn(n_samples)
1133+
y = np.dot(X, coef) + 0.01 * random_state.randn(n_samples)
11331134
alpha = 1.0
1134-
rtol = 1e-2 if os.name == 'nt' and _IS_32BIT else 1e-5
1135-
11361135
results = dict()
1136+
# XXX: Sparse CG seems to be far less numerically stable than the
1137+
# others, maybe we should not enable float32 for this one.
1138+
atol = 1e-3 if solver == "sparse_cg" else 1e-5
11371139
for current_dtype in (np.float32, np.float64):
11381140
results[current_dtype] = ridge_regression(X.astype(current_dtype),
11391141
y.astype(current_dtype),
@@ -1148,4 +1150,4 @@ def test_ridge_regression_dtype_stability(solver):
11481150

11491151
assert results[np.float32].dtype == np.float32
11501152
assert results[np.float64].dtype == np.float64
1151-
assert_allclose(results[np.float32], results[np.float64], rtol=rtol)
1153+
assert_allclose(results[np.float32], results[np.float64], atol=atol)

0 commit comments

Comments
 (0)
0