Fix PLS scaling bug (#7819) · scikit-learn/scikit-learn@b15818e · GitHub
[go: up one dir, main page]

Skip to content

Commit b15818e

Browse files
jayzed82 authored and lesteve committed
Fix PLS scaling bug (#7819)
1 parent 0bdd8bf commit b15818e

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@ Bug fixes
345345

346346
- Fixed a memory leak in our LibLinear implementation. :issue:`9024` by
347347
:user:`Sergei Lebedev <superbobry>`
348+
- Fixed improper scaling in :class:`sklearn.cross_decomposition.PLSRegression`
349+
with ``scale=True``. :issue:`7819` by :user:`jayzed82 <jayzed82>`.
348350

349351
- Fixed oob_score in :class:`ensemble.BaggingClassifier`.
350352
:issue:`#8936` by :user:`mlewis1729 <mlewis1729>`

sklearn/cross_decomposition/pls_.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,7 @@ def fit(self, X, Y):
366366
# Y = X W(P'W)^-1Q' + Err = XB + Err
367367
# => B = W*Q' (p x q)
368368
self.coef_ = np.dot(self.x_rotations_, self.y_loadings_.T)
369-
self.coef_ = (1. / self.x_std_.reshape((p, 1)) * self.coef_ *
370-
self.y_std_)
369+
self.coef_ = self.coef_ * self.y_std_
371370
return self
372371

373372
def transform(self, X, Y=None, copy=True):

sklearn/cross_decomposition/tests/test_pls.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import numpy as np
2+
from numpy.testing import assert_approx_equal
3+
24
from sklearn.utils.testing import (assert_equal, assert_array_almost_equal,
35
assert_array_equal, assert_true,
46
assert_raise_message)
57
from sklearn.datasets import load_linnerud
68
from sklearn.cross_decomposition import pls_, CCA
9+
from sklearn.preprocessing import StandardScaler
710

811

912
def test_pls():
@@ -351,11 +354,38 @@ def test_scale_and_stability():
351354
assert_array_almost_equal(X_s_score, X_score)
352355
assert_array_almost_equal(Y_s_score, Y_score)
353356

357+
354358
def test_pls_errors():
355359
d = load_linnerud()
356360
X = d.data
357361
Y = d.target
358362
for clf in [pls_.PLSCanonical(), pls_.PLSRegression(),
359363
pls_.PLSSVD()]:
360364
clf.n_components = 4
361-
assert_raise_message(ValueError, "Invalid number of components", clf.fit, X, Y)
365+
assert_raise_message(ValueError, "Invalid number of components",
366+
clf.fit, X, Y)
367+
368+
369+
def test_pls_scaling():
370+
# sanity check for scale=True
371+
n_samples = 1000
372+
n_targets = 5
373+
n_features = 10
374+
375+
rng = np.random.RandomState(0)
376+
377+
Q = rng.randn(n_targets, n_features)
378+
Y = rng.randn(n_samples, n_targets)
379+
X = np.dot(Y, Q) + 2 * rng.randn(n_samples, n_features) + 1
380+
X *= 1000
381+
X_scaled = StandardScaler().fit_transform(X)
382+
383+
pls = pls_.PLSRegression(n_components=5, scale=True)
384+
385+
pls.fit(X, Y)
386+
score = pls.score(X, Y)
387+
388+
pls.fit(X_scaled, Y)
389+
score_scaled = pls.score(X_scaled, Y)
390+
391+
assert_approx_equal(score, score_scaled)

0 commit comments

Comments
 (0)
0