API add intercept_ attribute to PLS estimators (#22015) · scikit-learn/scikit-learn@452ede0 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit 452ede0

Browse files
glemaitre, thomasjpfan, jeremiedbb
authored
API add intercept_ attribute to PLS estimators (#22015)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 32c53bc commit 452ede0

File tree

3 files changed

+54
-4
lines changed

3 files changed

+54
-4
lines changed

doc/whats_new/v1.1.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,13 @@ Changelog
299299
specific shape for `coef_` (e.g. :class:`feature_selection.RFE`).
300300
:pr:`22016` by :user:`Guillaume Lemaitre <glemaitre>`.
301301

302+
- |API| add the fitted attribute `intercept_` to
303+
:class:`cross_decomposition.PLSCanonical`,
304+
:class:`cross_decomposition.PLSRegression`, and
305+
:class:`cross_decomposition.CCA`. The method `predict` is indeed equivalent to
306+
`Y = X @ coef_ + intercept_`.
307+
:pr:`22015` by :user:`Guillaume Lemaitre <glemaitre>`.
308+
302309
:mod:`sklearn.datasets`
303310
.......................
304311

sklearn/cross_decomposition/_pls.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def fit(self, X, Y):
358358
# TODO(1.3): change `self._coef_` to `self.coef_`
359359
self._coef_ = np.dot(self.x_rotations_, self.y_loadings_.T)
360360
self._coef_ = (self._coef_ * self._y_std).T
361+
self.intercept_ = self._y_mean
361362
self._n_features_out = self.x_rotations_.shape[1]
362363
return self
363364

@@ -473,7 +474,7 @@ def predict(self, X, copy=True):
473474
X /= self._x_std
474475
# TODO(1.3): change `self._coef_` to `self.coef_`
475476
Ypred = X @ self._coef_.T
476-
return Ypred + self._y_mean
477+
return Ypred + self.intercept_
477478

478479
def fit_transform(self, X, y=None):
479480
"""Learn and apply the dimension reduction on the train data.
@@ -501,6 +502,7 @@ def coef_(self):
501502
# TODO(1.3): remove and change `self._coef_` to `self.coef_`
502503
# remove catch warnings from `_get_feature_importances`
503504
# delete self._coef_no_warning
505+
# update the docstring of `coef_` and `intercept_` attribute
504506
if hasattr(self, "_coef_") and getattr(self, "_coef_warning", True):
505507
warnings.warn(
506508
"The attribute `coef_` will be transposed in version 1.3 to be "
@@ -581,7 +583,13 @@ class PLSRegression(_PLS):
581583
582584
coef_ : ndarray of shape (n_features, n_targets)
583585
The coefficients of the linear model such that `Y` is approximated as
584-
`Y = X @ coef_`.
586+
`Y = X @ coef_ + intercept_`.
587+
588+
intercept_ : ndarray of shape (n_targets,)
589+
The intercepts of the linear model such that `Y` is approximated as
590+
`Y = X @ coef_ + intercept_`.
591+
592+
.. versionadded:: 1.1
585593
586594
n_iter_ : list of shape (n_components,)
587595
Number of iterations of the power method, for each
@@ -715,7 +723,13 @@ class PLSCanonical(_PLS):
715723
716724
coef_ : ndarray of shape (n_features, n_targets)
717725
The coefficients of the linear model such that `Y` is approximated as
718-
`Y = X @ coef_`.
726+
`Y = X @ coef_ + intercept_`.
727+
728+
intercept_ : ndarray of shape (n_targets,)
729+
The intercepts of the linear model such that `Y` is approximated as
730+
`Y = X @ coef_ + intercept_`.
731+
732+
.. versionadded:: 1.1
719733
720734
n_iter_ : list of shape (n_components,)
721735
Number of iterations of the power method, for each
@@ -827,7 +841,13 @@ class CCA(_PLS):
827841
828842
coef_ : ndarray of shape (n_features, n_targets)
829843
The coefficients of the linear model such that `Y` is approximated as
830-
`Y = X @ coef_`.
844+
`Y = X @ coef_ + intercept_`.
845+
846+
intercept_ : ndarray of shape (n_targets,)
847+
The intercepts of the linear model such that `Y` is approximated as
848+
`Y = X @ coef_ + intercept_`.
849+
850+
.. versionadded:: 1.1
831851
832852
n_iter_ : list of shape (n_components,)
833853
Number of iterations of the power method, for each

sklearn/cross_decomposition/tests/test_pls.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,29 @@ def test_pls_coef_shape(PLSEstimator):
616616
assert pls._coef_.shape == (Y.shape[1], X.shape[1])
617617

618618

619+
# TODO (1.3): remove the filterwarnings and adapt the dot product between `X_trans` and
620+
# `pls.coef_`
621+
@pytest.mark.filterwarnings("ignore:The attribute `coef_` will be transposed")
622+
@pytest.mark.parametrize("scale", [True, False])
623+
@pytest.mark.parametrize("PLSEstimator", [PLSRegression, PLSCanonical, CCA])
624+
def test_pls_prediction(PLSEstimator, scale):
625+
"""Check the behaviour of the prediction function."""
626+
d = load_linnerud()
627+
X = d.data
628+
Y = d.target
629+
630+
pls = PLSEstimator(copy=True, scale=scale).fit(X, Y)
631+
Y_pred = pls.predict(X, copy=True)
632+
633+
y_mean = Y.mean(axis=0)
634+
X_trans = X - X.mean(axis=0)
635+
if scale:
636+
X_trans /= X.std(axis=0, ddof=1)
637+
638+
assert_allclose(pls.intercept_, y_mean)
639+
assert_allclose(Y_pred, X_trans @ pls.coef_ + pls.intercept_)
640+
641+
619642
@pytest.mark.parametrize("Klass", [CCA, PLSSVD, PLSRegression, PLSCanonical])
620643
def test_pls_feature_names_out(Klass):
621644
"""Check `get_feature_names_out` cross_decomposition module."""

0 commit comments

Comments
 (0)
0