8000 revert pls changes · scikit-learn/scikit-learn@e9ea2d1 · GitHub
[go: up one dir, main page]

Skip to content

Commit e9ea2d1

Browse files
Maria AndriopoulouMaria Andriopoulou
Maria Andriopoulou
authored and
Maria Andriopoulou
committed
revert pls changes
1 parent c74d4ee commit e9ea2d1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

sklearn/cross_decomposition/pls_.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,14 +394,14 @@ def transform(self, X, Y=None, copy=True):
394394
x_scores if Y is not given, (x_scores, y_scores) otherwise.
395395
"""
396396
check_is_fitted(self, 'x_mean_')
397-
X = check_array(X, copy=copy, dtype=np.float64)
397+
X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
398398
# Normalize
399399
X -= self.x_mean_
400400
X /= self.x_std_
401401
# Apply rotation
402402
x_scores = np.dot(X, self.x_rotations_)
403403
if Y is not None:
404-
Y = check_array(Y, ensure_2d=False, copy=copy, dtype=np.float64)
404+
Y = check_array(Y, ensure_2d=False, copy=copy, dtype=FLOAT_DTYPES)
405405
if Y.ndim == 1:
406406
Y = Y.reshape(-1, 1)
407407
Y -= self.y_mean_
@@ -429,7 +429,7 @@ def predict(self, X, copy=True):
429429
be an issue in high dimensional space.
430430
"""
431431
check_is_fitted(self, 'x_mean_')
432-
X = check_array(X, copy=copy, dtype=np.float64)
432+
X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
433433
# Normalize
434434
X -= self.x_mean_
435435
X /= self.x_std_

0 commit comments

Comments
 (0)
0