@@ -306,12 +306,9 @@ def predict_proba(self, X):
306
306
"""
307
307
check_is_fitted (self )
308
308
309
- # Based on Algorithm 3.2 of GPML
310
- K_star = self .kernel_ (self .X_train_ , X ) # K_star =k(x_star)
311
- f_star = K_star .T .dot (self .y_train_ - self .pi_ ) # Line 4
312
- v = solve (self .L_ , self .W_sr_ [:, np .newaxis ] * K_star ) # Line 5
313
- # Line 6 (compute np.diag(v.T.dot(v)) via einsum)
314
- var_f_star = self .kernel_ .diag (X ) - np .einsum ("ij,ij->j" , v , v )
309
+ # Compute the mean and variance of the latent function
310
+ # (Lines 4-6 of Algorithm 3.2 of GPML)
311
+ latent_mean , latent_var = self .latent_mean_and_variance (X )
315
312
316
313
# Line 7:
317
314
# Approximate \int log(z) * N(z | f_star, var_f_star)
@@ -320,12 +317,12 @@ def predict_proba(self, X):
320
317
# sigmoid by a linear combination of 5 error functions.
321
318
# For information on how this integral can be computed see
322
319
# blitiri.blogspot.de/2012/11/gaussian-integral-of-error-function.html
323
- alpha = 1 / (2 * var_f_star )
324
- gamma = LAMBDAS * f_star
320
+ alpha = 1 / (2 * latent_var )
321
+ gamma = LAMBDAS * latent_mean
325
322
integrals = (
326
323
np .sqrt (np .pi / alpha )
327
324
* erf (gamma * np .sqrt (alpha / (alpha + LAMBDAS ** 2 )))
328
- / (2 * np .sqrt (var_f_star * 2 * np .pi ))
325
+ / (2 * np .sqrt (latent_var * 2 * np .pi ))
329
326
)
330
327
pi_star = (COEFS * integrals ).sum (axis = 0 ) + 0.5 * COEFS .sum ()
331
328
@@ -410,6 +407,39 @@ def log_marginal_likelihood(
410
407
411
408
return Z , d_Z
412
409
410
def latent_mean_and_variance(self, X):
    """Compute the mean and variance of the latent function values.

    Implements lines 4-6 of Algorithm 3.2 in [RW2006]_: the Laplace
    approximation to the posterior over the latent function of a
    Gaussian process classifier, evaluated at the query points.

    Note that this function is only supported for binary classification.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features) or list of object
        Query points where the GP is evaluated for classification.

    Returns
    -------
    latent_mean : array-like of shape (n_samples,)
        Mean of the latent function values at the query points.

    latent_var : array-like of shape (n_samples,)
        Variance of the latent function values at the query points.
    """
    check_is_fitted(self)

    # Cross-covariance between training inputs and query points,
    # K_trans = k(X_train, X_star).
    K_trans = self.kernel_(self.X_train_, X)
    # Line 4: posterior mean of f_star is K_star^T (y - pi).
    latent_mean = K_trans.T @ (self.y_train_ - self.pi_)
    # Line 5: v = L \ (W^{1/2} K_star), using the Cholesky factor L_.
    v = solve(self.L_, self.W_sr_[:, np.newaxis] * K_trans)
    # Line 6: var(f_star) = diag(k(X_star, X_star)) - column-wise
    # squared norms of v (equivalent to np.diag(v.T @ v)).
    latent_var = self.kernel_.diag(X) - (v * v).sum(axis=0)

    return latent_mean, latent_var
413
443
def _posterior_mode (self , K , return_temporaries = False ):
414
444
"""Mode-finding for binary Laplace GPC and fixed kernel.
415
445
@@ -902,3 +932,40 @@ def log_marginal_likelihood(
902
932
"Obtained theta with shape %d."
903
933
% (n_dims , n_dims * self .classes_ .shape [0 ], theta .shape [0 ])
904
934
)
935
+
936
def latent_mean_and_variance(self, X):
    """Compute the mean and variance of the latent function.

    Delegates to the fitted binary Laplace estimator, which implements
    lines 4-6 of Algorithm 3.2 in [RW2006]_ for the latent mean and
    variance of the Gaussian process classification model.

    Note that this function is only supported for binary classification.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features) or list of object
        Query points where the GP is evaluated for classification.

    Returns
    -------
    latent_mean : array-like of shape (n_samples,)
        Mean of the latent function values at the query points.

    latent_var : array-like of shape (n_samples,)
        Variance of the latent function values at the query points.
    """
    # The multiclass (one-vs-rest) wrapper has no single latent
    # function, so only the binary case is supported.
    if self.n_classes_ > 2:
        raise ValueError(
            "Returning the mean and variance of the latent function f "
            "is only supported for binary classification, received "
            f"{self.n_classes_} classes."
        )
    check_is_fitted(self)

    # Enforce numeric 2-D input only for kernels that operate on
    # vectors; structured-data kernels accept arbitrary objects.
    if self.kernel is None or self.kernel.requires_vector_input:
        validate_kwargs = dict(ensure_2d=True, dtype="numeric")
    else:
        validate_kwargs = dict(ensure_2d=False, dtype=None)
    X = validate_data(self, X, reset=False, **validate_kwargs)

    return self.base_estimator_.latent_mean_and_variance(X)
0 commit comments