|
7 | 7 | import numpy as np
|
8 | 8 |
|
9 | 9 | from ..utils import check_random_state, check_array
|
10 |
| -from ..utils.validation import check_is_fitted |
| 10 | +from ..utils.validation import check_is_fitted, FLOAT_DTYPES |
11 | 11 | from ..linear_model import ridge_regression
|
12 | 12 | from ..base import BaseEstimator, TransformerMixin
|
13 | 13 | from .dict_learning import dict_learning, dict_learning_online
|
@@ -115,7 +115,7 @@ def fit(self, X, y=None):
|
115 | 115 | Returns the instance itself.
|
116 | 116 | """
|
117 | 117 | random_state = check_random_state(self.random_state)
|
118 |
| - X = check_array(X) |
| 118 | + X = check_array(X, dtype=FLOAT_DTYPES) |
119 | 119 | if self.n_components is None:
|
120 | 120 | n_components = X.shape[1]
|
121 | 121 | else:
|
@@ -168,7 +168,7 @@ def transform(self, X, ridge_alpha='deprecated'):
|
168 | 168 | """
|
169 | 169 | check_is_fitted(self, 'components_')
|
170 | 170 |
|
171 |
| - X = check_array(X) |
| 171 | + X = check_array(X, dtype=FLOAT_DTYPES) |
172 | 172 | if ridge_alpha != 'deprecated':
|
173 | 173 | warnings.warn("The ridge_alpha parameter on transform() is "
|
174 | 174 | "deprecated since 0.19 and will be removed in 0.21. "
|
@@ -285,7 +285,7 @@ def fit(self, X, y=None):
|
285 | 285 | Returns the instance itself.
|
286 | 286 | """
|
287 | 287 | random_state = check_random_state(self.random_state)
|
288 |
| - X = check_array(X) |
| 288 | + X = check_array(X, dtype=FLOAT_DTYPES) |
289 | 289 | if self.n_components is None:
|
290 | 290 | n_components = X.shape[1]
|
291 | 291 | else:
|
|
0 commit comments