8000 FIX compute y_std and y_cov properly with multi-target in GPR (#20761) · scikit-learn/scikit-learn@9b210ae · GitHub
[go: up one dir, main page]

Skip to content

Commit 9b210ae

Browse files
authored
FIX compute y_std and y_cov properly with mu 8000 lti-target in GPR (#20761)
1 parent 337e0d2 commit 9b210ae

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

doc/whats_new/v1.0.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ Fixed models
4343
between sparse and dense input. :pr:`21195`
4444
by :user:`Jérémie du Boisberranger <jeremiedbb>`.
4545

46+
:mod:`sklearn.gaussian_process`
47+
...............................
48+
49+
- |Fix| Compute `y_std` properly with multi-target in
50+
:class:`sklearn.gaussian_process.GaussianProcessRegressor` allowing
51+
proper normalization in multi-target scene.
52+
:pr:`20761` by :user:`Patrick de C. T. R. Ferreira <patrickctrf>`.
53+
4654
:mod:`sklearn.feature_extraction`
4755
.................................
4856

sklearn/gaussian_process/_gpr.py

Lines changed: 18 additions & 4 deletions< 8000 div class="d-flex">
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,12 @@ def predict(self, X, return_std=False, return_cov=False):
349349
y_mean : ndarray of shape (n_samples,) or (n_samples, n_targets)
350350
Mean of predictive distribution a query points.
351351
352-
y_std : ndarray of shape (n_samples,), optional
352+
y_std : ndarray of shape (n_samples,) or (n_samples, n_targets), optional
353353
Standard deviation of predictive distribution at query points.
354354
Only returned when `return_std` is True.
355355
356-
y_cov : ndarray of shape (n_samples, n_samples), optional
356+
y_cov : ndarray of shape (n_samples, n_samples) or \
357+
(n_samples, n_samples, n_targets), optional
357358
Covariance of joint predictive distribution a query points.
358359
Only returned when `return_cov` is True.
359360
"""
@@ -403,7 +404,14 @@ def predict(self, X, return_std=False, return_cov=False):
403404
y_cov = self.kernel_(X) - V.T @ V
404405

405406
# undo normalisation
406-
y_cov = y_cov * self._y_train_std ** 2
407+
y_cov = np.outer(y_cov, self._y_train_std ** 2). 8000 reshape(
408+
*y_cov.shape, -1
409+
)
410+
411+
# if y_cov has shape (n_samples, n_samples, 1), reshape to
412+
# (n_samples, n_samples)
413+
if y_cov.shape[2] == 1:
414+
y_cov = np.squeeze(y_cov, axis=2)
407415

408416
return y_mean, y_cov
409417
elif return_std:
@@ -424,7 +432,13 @@ def predict(self, X, return_std=False, return_cov=False):
424432
y_var[y_var_negative] = 0.0
425433

426434
# undo normalisation
427-
y_var = y_var * self._y_train_std ** 2
435+
y_var = np.outer(y_var, self._y_train_std ** 2).reshape(
436+
*y_var.shape, -1
437+
)
438+
439+
# if y_var has shape (n_samples, 1), reshape to (n_samples,)
440+
if y_var.shape[1] == 1:
441+
y_var = np.squeeze(y_var, axis=1)
428442

429443
return y_mean, np.sqrt(y_var)
430444
else:

sklearn/gaussian_process/tests/test_gpr.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,33 @@ def test_gpr_predict_error():
652652
err_msg = "At most one of return_std or return_cov can be requested."
653653
with pytest.raises(RuntimeError, match=err_msg):
654654
gpr.predict(X, return_cov=True, return_std=True)
655+
656+
657+
def test_y_std_with_multitarget_normalized():
658+
"""Check the proper normalization of `y_std` and `y_cov` in multi-target scene.
659+
660+
Non-regression test for:
661+
https://github.com/scikit-learn/scikit-learn/issues/17394
662+
https://github.com/scikit-learn/scikit-learn/issues/18065
663+
"""
664+
rng = np.random.RandomState(1234)
665+
666+
n_samples, n_features, n_targets = 12, 10, 6
667+
668+
X_train = rng.randn(n_samples, n_features)
669+
y_train = rng.randn(n_samples, n_targets)
670+
X_test = rng.randn(n_samples, n_features)
671+
672+
# Generic kernel
673+
kernel = WhiteKernel(1.0, (1e-1, 1e3)) * C(10.0, (1e-3, 1e3))
674+
675+
model = GaussianProcessRegressor(
676+
kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=True
677+
)
678+
model.fit(X_train, y_train)
679+
y_pred, y_std = model.predict(X_test, return_std=True)
680+
_, y_cov = model.predict(X_test, return_cov=True)
681+
682+
assert y_pred.shape == (n_samples, n_targets)
683+
assert y_std.shape == (n_samples, n_targets)
684+
assert y_cov.shape == (n_samples, n_samples, n_targets)

0 commit comments

Comments
 (0)
0