8000 ENH Exposes latent mean and variance for GPCs (#22227) · jeremiedbb/scikit-learn@b55aba5 · GitHub
[go: up one dir, main page]

Skip to content

Commit b55aba5

Browse files
ENH Exposes latent mean and variance for GPCs (scikit-learn#22227)
Co-authored-by: antoinebaker <antoinebaker@users.noreply.github.com>
1 parent f29c100 commit b55aba5

File tree

4 files changed

+119
-11
lines changed

4 files changed

+119
-11
lines changed

doc/modules/gaussian_process.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ The :class:`GaussianProcessClassifier` implements Gaussian processes (GP) for
106106
classification purposes, more specifically for probabilistic classification,
107107
where test predictions take the form of class probabilities.
108108
GaussianProcessClassifier places a GP prior on a latent function :math:`f`,
109-
which is then squashed through a link function to obtain the probabilistic
109+
which is then squashed through a link function :math:`\pi` to obtain the probabilistic
110110
classification. The latent function :math:`f` is a so-called nuisance function,
111111
whose values are not observed and are not relevant by themselves.
112112
Its purpose is to allow a convenient formulation of the model, and :math:`f`
113-
is removed (integrated out) during prediction. GaussianProcessClassifier
113+
is removed (integrated out) during prediction. :class:`GaussianProcessClassifier`
114114
implements the logistic link function, for which the integral cannot be
115115
computed analytically but is easily approximated in the binary case.
116116

@@ -134,6 +134,11 @@ that have been chosen randomly from the range of allowed values.
134134
If the initial hyperparameters should be kept fixed, `None` can be passed as
135135
optimizer.
136136

137+
In some scenarios, information about the latent function :math:`f` is desired
138+
(i.e. the mean :math:`\bar{f}_*` and the variance :math:`\text{Var}[f_*]` described
139+
in Eqs. (3.21) and (3.24) of [RW2006]_). The :class:`GaussianProcessClassifier`
140+
provides access to these quantities via the `latent_mean_and_variance` method.
141+
137142
:class:`GaussianProcessClassifier` supports multi-class classification
138143
by performing either one-versus-rest or one-versus-one based training and
139144
prediction. In one-versus-rest, one binary Gaussian process classifier is
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- :class:`gaussian_process.GaussianProcessClassifier` now includes a `latent_mean_and_variance` method that exposes the mean and the variance of the latent function, :math:`f`, used in the Laplace approximation. By :user:`Miguel González Duque <miguelgondu>`

sklearn/gaussian_process/_gpc.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,9 @@ def predict_proba(self, X):
306306
"""
307307
check_is_fitted(self)
308308

309-
# Based on Algorithm 3.2 of GPML
310-
K_star = self.kernel_(self.X_train_, X) # K_star =k(x_star)
311-
f_star = K_star.T.dot(self.y_train_ - self.pi_) # Line 4
312-
v = solve(self.L_, self.W_sr_[:, np.newaxis] * K_star) # Line 5
313-
# Line 6 (compute np.diag(v.T.dot(v)) via einsum)
314-
var_f_star = self.kernel_.diag(X) - np.einsum("ij,ij->j", v, v)
309+
# Compute the mean and variance of the latent function
310+
# (Lines 4-6 of Algorithm 3.2 of GPML)
311+
latent_mean, latent_var = self.latent_mean_and_variance(X)
315312

316313
# Line 7:
317314
# Approximate \int log(z) * N(z | f_star, var_f_star)
@@ -320,12 +317,12 @@ def predict_proba(self, X):
320317
# sigmoid by a linear combination of 5 error functions.
321318
# For information on how this integral can be computed see
322319
# blitiri.blogspot.de/2012/11/gaussian-integral-of-error-function.html
323-
alpha = 1 / (2 * var_f_star)
324-
gamma = LAMBDAS * f_star
320+
alpha = 1 / (2 * latent_var)
321+
gamma = LAMBDAS * latent_mean
325322
integrals = (
326323
np.sqrt(np.pi / alpha)
327324
* erf(gamma * np.sqrt(alpha / (alpha + LAMBDAS**2)))
328-
/ (2 * np.sqrt(var_f_star * 2 * np.pi))
325+
/ (2 * np.sqrt(latent_var * 2 * np.pi))
329326
)
330327
pi_star = (COEFS * integrals).sum(axis=0) + 0.5 * COEFS.sum()
331328

@@ -410,6 +407,39 @@ def log_marginal_likelihood(
410407

411408
return Z, d_Z
412409

410+
def latent_mean_and_variance(self, X):
411+
"""Compute the mean and variance of the latent function values.
412+
413+
Based on algorithm 3.2 of [RW2006]_, this function returns the latent
414+
mean (Line 4) and variance (Line 6) of the Gaussian process
415+
classification model.
416+
417+
Note that this function is only supported for binary classification.
418+
419+
Parameters
420+
----------
421+
X : array-like of shape (n_samples, n_features) or list of object
422+
Query points where the GP is evaluated for classification.
423+
424+
Returns
425+
-------
426+
latent_mean : array-like of shape (n_samples,)
427+
Mean of the latent function values at the query points.
428+
429+
latent_var : array-like of shape (n_samples,)
430+
Variance of the latent function values at the query points.
431+
"""
432+
check_is_fitted(self)
433+
434+
# Based on Algorithm 3.2 of GPML
435+
K_star = self.kernel_(self.X_train_, X) # K_star =k(x_star)
436+
latent_mean = K_star.T.dot(self.y_train_ - self.pi_) # Line 4
437+
v = solve(self.L_, self.W_sr_[:, np.newaxis] * K_star) # Line 5
438+
# Line 6 (compute np.diag(v.T.dot(v)) via einsum)
439+
latent_var = self.kernel_.diag(X) - np.einsum("ij,ij->j", v, v)
440+
441+
return latent_mean, latent_var
442+
413443
def _posterior_mode(self, K, return_temporaries=False):
414444
"""Mode-finding for binary Laplace GPC and fixed kernel.
415445
@@ -902,3 +932,40 @@ def log_marginal_likelihood(
902932
"Obtained theta with shape %d."
903933
% (n_dims, n_dims * self.classes_.shape[0], theta.shape[0])
904934
)
935+
936+
def latent_mean_and_variance(self, X):
937+
"""Compute the mean and variance of the latent function.
938+
939+
Based on algorithm 3.2 of [RW2006]_, this function returns the latent
940+
mean (Line 4) and variance (Line 6) of the Gaussian process
941+
classification model.
942+
943+
Note that this function is only supported for binary classification.
944+
945+
Parameters
946+
----------
947+
X : array-like of shape (n_samples, n_features) or list of object
948+
Query points where the GP is evaluated for classification.
949+
950+
Returns
951+
-------
952+
latent_mean : array-like of shape (n_samples,)
953+
Mean of the latent function values at the query points.
954+
955+
latent_var : array-like of shape (n_samples,)
956+
Variance of the latent function values at the query points.
957+
"""
958+
if self.n_classes_ > 2:
959+
raise ValueError(
960+
"Returning the mean and variance of the latent function f "
961+
"is only supported for binary classification, received "
962+
f"{self.n_classes_} classes."
963+
)
964+
check_is_fitted(self)
965+
966+
if self.kernel is None or self.kernel.requires_vector_input:
967+
X = validate_data(self, X, ensure_2d=True, dtype="numeric", reset=False)
968+
else:
969+
X = validate_data(self, X, ensure_2d=False, dtype=None, reset=False)
970+
971+
return self.base_estimator_.latent_mean_and_variance(X)

sklearn/gaussian_process/tests/test_gpc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,38 @@ def test_gpc_fit_error(params, error_type, err_msg):
283283
gpc = GaussianProcessClassifier(**params)
284284
with pytest.raises(error_type, match=err_msg):
285285
gpc.fit(X, y)
286+
287+
288+
@pytest.mark.parametrize("kernel", kernels)
289+
def test_gpc_latent_mean_and_variance_shape(kernel):
290+
"""Checks that the latent mean and variance have the right shape."""
291+
gpc = GaussianProcessClassifier(kernel=kernel)
292+
gpc.fit(X, y)
293+
294+
# Check that the latent mean and variance have the right shape
295+
latent_mean, latent_variance = gpc.latent_mean_and_variance(X)
296+
assert latent_mean.shape == (X.shape[0],)
297+
assert latent_variance.shape == (X.shape[0],)
298+
299+
300+
def test_gpc_latent_mean_and_variance_complain_on_more_than_2_classes():
301+
"""Checks that a ValueError is raised when there are more than 2 classes."""
302+
gpc = GaussianProcessClassifier(kernel=RBF())
303+
gpc.fit(X, y_mc)
304+
305+
# Check that requesting the latent mean and variance raises for multiclass targets
306+
with pytest.raises(
307+
ValueError,
308+
match="Returning the mean and variance of the latent function f "
309+
"is only supported for binary classification",
310+
):
311+
gpc.latent_mean_and_variance(X)
312+
313+
314+
def test_latent_mean_and_variance_works_on_structured_kernels():
315+
X = ["A", "AB", "B"]
316+
y = np.array([True, False, True])
317+
kernel = MiniSeqKernel(baseline_similarity_bounds="fixed")
318+
gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y)
319+
320+
gpc.latent_mean_and_variance(X)

0 commit comments

Comments
 (0)
0