|
36 | 36 | def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3, verbose=0):
|
37 | 37 | n_samples, n_features = X.shape
|
38 | 38 | X1 = sp_linalg.aslinearoperator(X)
|
39 |
| - coefs = np.empty((y.shape[1], n_features)) |
| 39 | + coefs = np.empty((y.shape[1], n_features), dtype=X.dtype) |
40 | 40 |
|
41 | 41 | if n_features > n_samples:
|
42 | 42 | def create_mv(curr_alpha):
|
@@ -80,7 +80,7 @@ def _mv(x):
|
80 | 80 |
|
81 | 81 | def _solve_lsqr(X, y, alpha, max_iter=None, tol=1e-3):
|
82 | 82 | n_samples, n_features = X.shape
|
83 |
| - coefs = np.empty((y.shape[1], n_features)) |
| 83 | + coefs = np.empty((y.shape[1], n_features), dtype=X.dtype) |
84 | 84 | n_iter = np.empty(y.shape[1], dtype=np.int32)
|
85 | 85 |
|
86 | 86 | # According to the lsqr documentation, alpha = damp^2.
|
@@ -111,7 +111,7 @@ def _solve_cholesky(X, y, alpha):
|
111 | 111 | return linalg.solve(A, Xy, sym_pos=True,
|
112 | 112 | overwrite_a=True).T
|
113 | 113 | else:
|
114 |
| - coefs = np.empty([n_targets, n_features]) |
| 114 | + coefs = np.empty([n_targets, n_features], dtype=X.dtype) |
115 | 115 | for coef, target, current_alpha in zip(coefs, Xy.T, alpha):
|
116 | 116 | A.flat[::n_features + 1] += current_alpha
|
117 | 117 | coef[:] = linalg.solve(A, target, sym_pos=True,
|
@@ -186,7 +186,7 @@ def _solve_svd(X, y, alpha):
|
186 | 186 | idx = s > 1e-15 # same default value as scipy.linalg.pinv
|
187 | 187 | s_nnz = s[idx][:, np.newaxis]
|
188 | 188 | UTy = np.dot(U.T, y)
|
189 |
| - d = np.zeros((s.size, alpha.size)) |
| 189 | + d = np.zeros((s.size, alpha.size), dtype=X.dtype) |
190 | 190 | d[idx] = s_nnz / (s_nnz ** 2 + alpha)
|
191 | 191 | d_UT_y = d * UTy
|
192 | 192 | return np.dot(Vt.T, d_UT_y).T
|
@@ -371,7 +371,7 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
|
371 | 371 | X, y = _rescale_data(X, y, sample_weight)
|
372 | 372 |
|
373 | 373 | # There should be either 1 or n_targets penalties
|
374 |
| - alpha = np.asarray(alpha).ravel() |
| 374 | + alpha = np.asarray(alpha, dtype=X.dtype).ravel() |
375 | 375 | if alpha.size not in [1, n_targets]:
|
376 | 376 | raise ValueError("Number of targets and number of penalties "
|
377 | 377 | "do not correspond: %d != %d"
|
@@ -469,7 +469,13 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
|
469 | 469 | self.random_state = random_state
|
470 | 470 |
|
471 | 471 | def fit(self, X, y, sample_weight=None):
|
472 |
| - X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=np.float64, |
| 472 | + |
| 473 | + if self.solver in ['svd', 'sparse_cg', 'cholesky', 'lsqr']: |
| 474 | + _dtype = [np.float64, np.float32] |
| 475 | + else: |
| 476 | + _dtype = np.float64 |
| 477 | + |
| 478 | + X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=_dtype, |
473 | 479 | multi_output=True, y_numeric=True)
|
474 | 480 |
|
475 | 481 | if ((sample_weight is not None) and
|
|
0 commit comments