8000 TST Better tests in ridge float64 upcasting (#9125) · jwjohnson314/scikit-learn@3805b8f · GitHub
[go: up one dir, main page]

Skip to content

Commit 3805b8f

Browse files
massichJeremiah Johnson
authored andcommitted
TST Better tests in ridge float64 upcasting (scikit-learn#9125)
Also invert the solvers check to highlight the fact that new solver should support both 32 bit and 64 bit float by default from now on.
1 parent 187b957 commit 3805b8f

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

sklearn/linear_model/ridge.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,11 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
472472

473473
def fit(self, X, y, sample_weight=None):
474474

475-
if self.solver in ['svd', 'sparse_cg', 'cholesky', 'lsqr']:
476-
_dtype = [np.float64, np.float32]
477-
else:
475+
if self.solver in ('sag', 'saga'):
478476
_dtype = np.float64
477+
else:
478+
# all other solvers work at both float precision levels
479+
_dtype = [np.float64, np.float32]
479480

480481
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=_dtype,
481482
multi_output=True, y_numeric=True)

sklearn/linear_model/tests/test_ridge.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -813,12 +813,12 @@ def test_dtype_match():
813813
ridge_64.fit(X_64, y_64)
814814
coef_64 = ridge_64.coef_
815815

816-
# Do all the checks at once, like this is easier to debug
817-
assert_almost_equal(ridge_32.coef_, ridge_64.coef_, decimal=5)
818-
819816
# Do the actual checks at once for easier debug
820-
assert_equal(coef_32.dtype, X_32.dtype)
821-
assert_equal(coef_64.dtype, X_64.dtype)
817+
assert coef_32.dtype == X_32.dtype
818+
assert coef_64.dtype == X_64.dtype
819+
assert ridge_32.predict(X_32).dtype == X_32.dtype
820+
assert ridge_64.predict(X_64).dtype == X_64.dtype
821+
assert_almost_equal(ridge_32.coef_, ridge_64.coef_, decimal 81C8 =5)
822822

823823

824824
def test_dtype_match_cholesky():
@@ -844,6 +844,8 @@ def test_dtype_match_cholesky():
844844
coef_64 = ridge_64.coef_
845845

846846
# Do all the checks at once, like this is easier to debug
847-
assert_equal(coef_32.dtype, X_32.dtype)
848-
assert_equal(coef_64.dtype, X_64.dtype)
847+
assert coef_32.dtype == X_32.dtype
848+
assert coef_64.dtype == X_64.dtype
849+
assert ridge_32.predict(X_32).dtype == X_32.dtype
850+
assert ridge_64.predict(X_64).dtype == X_64.dtype
849851
assert_almost_equal(ridge_32.coef_, ridge_64.coef_, decimal=5)

0 commit comments

Comments
 (0)
0