From 510638878e6f72b055fe7555c7e8a9782df78962 Mon Sep 17 00:00:00 2001
From: Antoine Baker
Date: Tue, 14 Jan 2025 12:11:28 +0100
Subject: [PATCH 1/9] fix log marginal likelihood and updates

---
 sklearn/linear_model/_bayes.py | 28 ++++++++++++++++++++--------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py
index b6527d4f22b1f..d8a63e0583783 100644
--- a/sklearn/linear_model/_bayes.py
+++ b/sklearn/linear_model/_bayes.py
@@ -264,14 +264,19 @@ def fit(self, X, y, sample_weight=None):
         self.X_scale_ = X_scale_
         n_samples, n_features = X.shape
 
+        sw_sum = n_samples
+        if sample_weight is not None:
+            sw_sum = sample_weight.sum()
+
         # Initialization of the values of the parameters
         eps = np.finfo(np.float64).eps
-        # Add `eps` in the denominator to omit division by zero if `np.var(y)`
-        # is zero
+        # Add `eps` in the denominator to omit division by zero
+        # if y_weighted_var is zero
         alpha_ = self.alpha_init
         lambda_ = self.lambda_init
         if alpha_ is None:
-            alpha_ = 1.0 / (np.var(y) + eps)
+            y_weighted_var = (y**2).sum() / sw_sum
+            alpha_ = 1.0 / (y_weighted_var + eps)
         if lambda_ is None:
             lambda_ = 1.0
 
@@ -309,7 +314,7 @@ def fit(self, X, y, sample_weight=None):
             # Update alpha and lambda according to (MacKay, 1992)
             gamma_ = np.sum((alpha_ * eigen_vals_) / (lambda_ + alpha_ * eigen_vals_))
             lambda_ = (gamma_ + 2 * lambda_1) / (np.sum(coef_**2) + 2 * lambda_2)
-            alpha_ = (n_samples - gamma_ + 2 * alpha_1) / (rmse_ + 2 * alpha_2)
+            alpha_ = (sw_sum - gamma_ + 2 * alpha_1) / (rmse_ + 2 * alpha_2)
 
             # Check for convergence
             if iter_ != 0 and np.sum(np.abs(coef_old_ - coef_)) < self.tol:
@@ -330,7 +335,14 @@ def fit(self, X, y, sample_weight=None):
         if self.compute_score:
             # compute the log marginal likelihood
             s = self._log_marginal_likelihood(
-                n_samples, n_features, eigen_vals_, alpha_, lambda_, coef_, rmse_
+                n_samples,
+                n_features,
+                sw_sum,
+                eigen_vals_,
+                alpha_,
+                lambda_,
+                coef_,
+                rmse_,
             )
             self.scores_.append(s)
         self.scores_ = np.array(self.scores_)
@@ -399,7 +411,7 @@ def _update_coef_(
 
         return coef_, rmse_
 
     def _log_marginal_likelihood(
-        self, n_samples, n_features, eigen_vals, alpha_, lambda_, coef, rmse
+        self, n_samples, n_features, sw_sum, eigen_vals, alpha_, lambda_, coef, rmse
     ):
         """Log marginal likelihood."""
         alpha_1 = self.alpha_1
@@ -421,11 +433,11 @@ def _log_marginal_likelihood(
         score += alpha_1 * log(alpha_) - alpha_2 * alpha_
         score += 0.5 * (
             n_features * log(lambda_)
-            + n_samples * log(alpha_)
+            + sw_sum * log(alpha_)
             - alpha_ * rmse
             - lambda_ * np.sum(coef**2)
             + logdet_sigma
-            - n_samples * log(2 * np.pi)
+            - sw_sum * log(2 * np.pi)
         )
 
         return score

From 1fafa3b2e712a6cd977f2c2856641858044c0f66 Mon Sep 17 00:00:00 2001
From: Antoine Baker
Date: Tue, 14 Jan 2025 12:18:07 +0100
Subject: [PATCH 2/9] rename rmse to mse

---
 sklearn/linear_model/_bayes.py | 30 ++++++++++++++----------------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py
index d8a63e0583783..658490f39a882 100644
--- a/sklearn/linear_model/_bayes.py
+++ b/sklearn/linear_model/_bayes.py
@@ -300,21 +300,21 @@ def fit(self, X, y, sample_weight=None):
         # Convergence loop of the bayesian ridge regression
         for iter_ in range(self.max_iter):
             # update posterior mean coef_ based on alpha_ and lambda_ and
-            # compute corresponding rmse
-            coef_, rmse_ = self._update_coef_(
+            # compute corresponding mse
+            coef_, mse_ = self._update_coef_(
                 X, y, n_samples, n_features, XT_y, U, Vh, eigen_vals_, alpha_, lambda_
             )
             if self.compute_score:
                 # compute the log marginal likelihood
                 s = self._log_marginal_likelihood(
-                    n_samples, n_features, eigen_vals_, alpha_, lambda_, coef_, rmse_
+                    n_samples, n_features, eigen_vals_, alpha_, lambda_, coef_, mse_
                 )
                 self.scores_.append(s)
 
             # Update alpha and lambda according to (MacKay, 1992)
             gamma_ = np.sum((alpha_ * eigen_vals_) / (lambda_ + alpha_ * eigen_vals_))
             lambda_ = (gamma_ + 2 * lambda_1) / (np.sum(coef_**2) + 2 * lambda_2)
-            alpha_ = (sw_sum - gamma_ + 2 * alpha_1) / (rmse_ + 2 * alpha_2)
+            alpha_ = (sw_sum - gamma_ + 2 * alpha_1) / (mse_ + 2 * alpha_2)
 
             # Check for convergence
             if iter_ != 0 and np.sum(np.abs(coef_old_ - coef_)) < self.tol:
@@ -329,7 +329,7 @@ def fit(self, X, y, sample_weight=None):
         # log marginal likelihood and posterior covariance
         self.alpha_ = alpha_
         self.lambda_ = lambda_
-        self.coef_, rmse_ = self._update_coef_(
+        self.coef_, mse_ = self._update_coef_(
            X, y, n_samples, n_features, XT_y, U, Vh, eigen_vals_, alpha_, lambda_
        )
        if self.compute_score:
@@ -342,7 +342,7 @@ def fit(self, X, y, sample_weight=None):
                 alpha_,
                 lambda_,
                 coef_,
-                rmse_,
+                mse_,
             )
             self.scores_.append(s)
         self.scores_ = np.array(self.scores_)
@@ -390,7 +390,7 @@ def predict(self, X, return_std=False):
     def _update_coef_(
         self, X, y, n_samples, n_features, XT_y, U, Vh, eigen_vals_, alpha_, lambda_
     ):
-        """Update posterior mean and compute corresponding rmse.
+        """Update posterior mean and compute corresponding mse.
 
         Posterior mean is given by coef_ = scaled_sigma_ * X.T * y where
         scaled_sigma_ = (lambda_/alpha_ * np.eye(n_features)
@@ -406,12 +406,12 @@ def _update_coef_(
                 [X.T, U / (eigen_vals_ + lambda_ / alpha_)[None, :], U.T, y]
             )
 
-        rmse_ = np.sum((y - np.dot(X, coef_)) ** 2)
+        mse_ = np.sum((y - np.dot(X, coef_)) ** 2)
 
-        return coef_, rmse_
+        return coef_, mse_
 
     def _log_marginal_likelihood(
-        self, n_samples, n_features, sw_sum, eigen_vals, alpha_, lambda_, coef, rmse
+        self, n_samples, n_features, sw_sum, eigen_vals, alpha_, lambda_, coef, mse
     ):
         """Log marginal likelihood."""
         alpha_1 = self.alpha_1
@@ -434,7 +434,7 @@ def _log_marginal_likelihood(
         score += 0.5 * (
             n_features * log(lambda_)
             + sw_sum * log(alpha_)
-            - alpha_ * rmse
+            - alpha_ * mse
             - lambda_ * np.sum(coef**2)
             + logdet_sigma
             - sw_sum * log(2 * np.pi)
@@ -696,14 +696,12 @@ def update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_):
             coef_ = update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_)
 
             # Update alpha and lambda
-            rmse_ = np.sum((y - np.dot(X, coef_)) ** 2)
+            mse_ = np.sum((y - np.dot(X, coef_)) ** 2)
             gamma_ = 1.0 - lambda_[keep_lambda] * np.diag(sigma_)
             lambda_[keep_lambda] = (gamma_ + 2.0 * lambda_1) / (
                 (coef_[keep_lambda]) ** 2 + 2.0 * lambda_2
             )
-            alpha_ = (n_samples - gamma_.sum() + 2.0 * alpha_1) / (
-                rmse_ + 2.0 * alpha_2
-            )
+            alpha_ = (n_samples - gamma_.sum() + 2.0 * alpha_1) / (mse_ + 2.0 * alpha_2)
 
             # Prune the weights with a precision over a threshold
             keep_lambda = lambda_ < self.threshold_lambda
@@ -718,7 +716,7 @@ def update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_):
                     + n_samples * log(alpha_)
                     + np.sum(np.log(lambda_))
                 )
-                s -= 0.5 * (alpha_ * rmse_ + (lambda_ * coef_**2).sum())
+                s -= 0.5 * (alpha_ * mse_ + (lambda_ * coef_**2).sum())
                 self.scores_.append(s)
 
             # Check for convergence
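[Note on patches 1-2] The fix above threads the total weight sw_sum through the MacKay (1992) fixed-point updates in place of the raw sample count, so that a sample with weight k behaves like k repeated samples. A minimal standalone sketch of one weighted iteration (illustrative only, not code from the series; it assumes sklearn's default hyperpriors alpha_1 = alpha_2 = lambda_1 = lambda_2 = 1e-6 and folds the weights in by sqrt-rescaling, as _rescale_data does):

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 20, 3
X = rng.normal(size=(n_samples, n_features))
y = X @ rng.normal(size=n_features) + 0.1 * rng.normal(size=n_samples)
sample_weight = rng.uniform(0.5, 2.0, size=n_samples)
sw_sum = sample_weight.sum()

alpha_1 = alpha_2 = lambda_1 = lambda_2 = 1e-6  # sklearn's default hyperpriors
alpha_, lambda_ = 1.0 / y.var(), 1.0

# Weights enter only through a sqrt-rescaling of X and y.
Xw = X * np.sqrt(sample_weight)[:, None]
yw = y * np.sqrt(sample_weight)
eigen_vals_ = np.linalg.svd(Xw, compute_uv=False) ** 2  # eigenvalues of Xw.T @ Xw

for _ in range(5):
    # posterior mean: (Xw.T @ Xw + lambda_/alpha_ * I)^-1 @ Xw.T @ yw
    coef_ = np.linalg.solve(
        Xw.T @ Xw + lambda_ / alpha_ * np.eye(n_features), Xw.T @ yw
    )
    squared_error = np.sum((yw - Xw @ coef_) ** 2)  # the rmse_/mse_ of the patches
    # MacKay updates, with n_samples replaced by sw_sum as in patch 1
    gamma_ = np.sum(alpha_ * eigen_vals_ / (lambda_ + alpha_ * eigen_vals_))
    lambda_ = (gamma_ + 2 * lambda_1) / (np.sum(coef_**2) + 2 * lambda_2)
    alpha_ = (sw_sum - gamma_ + 2 * alpha_1) / (squared_error + 2 * alpha_2)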
From ec7159ae6dd4c640f760863112d8a02a25e96c11 Mon Sep 17 00:00:00 2001
From: Antoine Baker
Date: Tue, 14 Jan 2025 14:56:17 +0100
Subject: [PATCH 3/9] clean xfail check

---
 sklearn/utils/_test_common/instance_generator.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/sklearn/utils/_test_common/instance_generator.py b/sklearn/utils/_test_common/instance_generator.py
index c46213b417090..c5e98c1a0d5dc 100644
--- a/sklearn/utils/_test_common/instance_generator.py
+++ b/sklearn/utils/_test_common/instance_generator.py
@@ -826,15 +826,6 @@ def _yield_instances_for_check(check, estimator_orig):
             "sample_weight is not equivalent to removing/repeating samples."
         ),
     },
-    BayesianRidge: {
-        # TODO: fix sample_weight handling of this estimator, see meta-issue #16298
-        "check_sample_weight_equivalence_on_dense_data": (
-            "sample_weight is not equivalent to removing/repeating samples."
-        ),
-        "check_sample_weight_equivalence_on_sparse_data": (
-            "sample_weight is not equivalent to removing/repeating samples."
-        ),
-    },
     BernoulliRBM: {
         "check_methods_subset_invariance": ("fails for the decision_function method"),
         "check_methods_sample_order_invariance": ("fails for the score_samples method"),

From aca84ccd2510a99bd786d39af8d3f9731624d9c1 Mon Sep 17 00:00:00 2001
From: Antoine Baker
Date: Tue, 14 Jan 2025 15:04:51 +0100
Subject: [PATCH 4/9] changelog

---
 .../upcoming_changes/sklearn.linear_model/30644.fix.rst | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 doc/whats_new/upcoming_changes/sklearn.linear_model/30644.fix.rst

diff --git a/doc/whats_new/upcoming_changes/sklearn.linear_model/30644.fix.rst b/doc/whats_new/upcoming_changes/sklearn.linear_model/30644.fix.rst
new file mode 100644
index 0000000000000..c9254fe350e28
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.linear_model/30644.fix.rst
@@ -0,0 +1,3 @@
+- The update and initialization of the hyperparameters now properly handle
+  sample weights in :class:`linear_model.BayesianRidge`.
+  By :user:`Antoine Baker `.
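[Note on patch 3] Removing these xfail entries declares that BayesianRidge is now expected to pass the sample-weight equivalence checks. A rough sketch of the property those checks enforce (not the actual common test): with the patches applied, fitting with integer sample weights should match fitting on the correspondingly repeated rows.

import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.default_rng(42)
X = rng.normal(size=(30, 4))
y = X @ rng.normal(size=4) + 0.1 * rng.normal(size=30)
w = rng.integers(1, 4, size=30)

est_weighted = BayesianRidge().fit(X, y, sample_weight=w)
est_repeated = BayesianRidge().fit(np.repeat(X, w, axis=0), np.repeat(y, w))

# The two fits should agree up to the solver's convergence tolerance.
np.testing.assert_allclose(est_weighted.coef_, est_repeated.coef_, rtol=1e-4)
np.testing.assert_allclose(est_weighted.alpha_, est_repeated.alpha_, rtol=1e-4)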
From cc235a57922f669c490032892bc7962c0f30267e Mon Sep 17 00:00:00 2001
From: Antoine Baker
Date: Tue, 14 Jan 2025 15:54:31 +0100
Subject: [PATCH 5/9] fix scores

---
 sklearn/linear_model/_bayes.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py
index 658490f39a882..d904392c99502 100644
--- a/sklearn/linear_model/_bayes.py
+++ b/sklearn/linear_model/_bayes.py
@@ -307,7 +307,14 @@ def fit(self, X, y, sample_weight=None):
             if self.compute_score:
                 # compute the log marginal likelihood
                 s = self._log_marginal_likelihood(
-                    n_samples, n_features, eigen_vals_, alpha_, lambda_, coef_, mse_
+                    n_samples,
+                    n_features,
+                    sw_sum,
+                    eigen_vals_,
+                    alpha_,
+                    lambda_,
+                    coef_,
+                    mse_,
                 )
                 self.scores_.append(s)
 

From 51854b21e3774b6036f2963139d1b8aba463993c Mon Sep 17 00:00:00 2001
From: Antoine Baker
Date: Wed, 15 Jan 2025 17:52:47 +0100
Subject: [PATCH 6/9] fix weighted variance

---
 sklearn/linear_model/_bayes.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py
index d904392c99502..5c3b231a290ab 100644
--- a/sklearn/linear_model/_bayes.py
+++ b/sklearn/linear_model/_bayes.py
@@ -244,9 +244,15 @@ def fit(self, X, y, sample_weight=None):
             y_numeric=True,
         )
         dtype = X.dtype
+        n_samples, n_features = X.shape
 
+        sw_sum = n_samples
+        y_var = y.var()
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X, dtype=dtype)
+            sw_sum = sample_weight.sum()
+            y_mean = np.average(y, weights=sample_weight)
+            y_var = np.average((y - y_mean) ** 2, weights=sample_weight)
 
         X, y, X_offset_, y_offset_, X_scale_ = _preprocess_data(
             X,
@@ -262,21 +268,14 @@ def fit(self, X, y, sample_weight=None):
 
         self.X_offset_ = X_offset_
         self.X_scale_ = X_scale_
-        n_samples, n_features = X.shape
-
-        sw_sum = n_samples
-        if sample_weight is not None:
-            sw_sum = sample_weight.sum()
 
         # Initialization of the values of the parameters
         eps = np.finfo(np.float64).eps
         # Add `eps` in the denominator to omit division by zero
-        # if y_weighted_var is zero
         alpha_ = self.alpha_init
         lambda_ = self.lambda_init
         if alpha_ is None:
-            y_weighted_var = (y**2).sum() / sw_sum
-            alpha_ = 1.0 / (y_weighted_var + eps)
+            alpha_ = 1.0 / (y_var + eps)
         if lambda_ is None:
             lambda_ = 1.0
 
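[Note on patch 6] Patch 1 initialized alpha_ from the uncentered second moment of the rescaled targets, (y**2).sum() / sw_sum; patch 6 instead computes a weighted variance centered on the weighted mean, before any preprocessing. A toy sketch of the difference (values chosen for illustration only):

import numpy as np

y = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([1.0, 1.0, 2.0, 2.0])

# Patch 1: uncentered weighted moment; on the sqrt-rescaled targets,
# (y**2).sum() / sw_sum is exactly (w * y**2).sum() / w.sum().
uncentered = (w * y**2).sum() / w.sum()

# Patch 6: weighted variance around the weighted mean.
y_mean = np.average(y, weights=w)
y_var = np.average((y - y_mean) ** 2, weights=w)

# The centered value matches the plain variance of the weight-repeated
# sample, which is what the equivalence check expects; the uncentered
# moment is inflated by the squared weighted mean.
y_rep = np.repeat(y, w.astype(int))
assert np.isclose(y_var, y_rep.var())
assert np.isclose(uncentered, y_var + y_mean**2)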
From e7ddaeb0172ef1a85088532bcafbbda38ff68db1 Mon Sep 17 00:00:00 2001
From: antoinebaker
Date: Thu, 16 Jan 2025 17:38:26 +0100
Subject: [PATCH 7/9] Update sklearn/linear_model/_bayes.py

Co-authored-by: Olivier Grisel
---
 sklearn/linear_model/_bayes.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py
index 5c3b231a290ab..85b494eeeb057 100644
--- a/sklearn/linear_model/_bayes.py
+++ b/sklearn/linear_model/_bayes.py
@@ -412,6 +412,8 @@ def _update_coef_(
                 [X.T, U / (eigen_vals_ + lambda_ / alpha_)[None, :], U.T, y]
             )
 
+        # Note: we do not need to explicit use the weights in this sum because
+        # y and X were proprocessed by _rescale_data to handle the weights.
         mse_ = np.sum((y - np.dot(X, coef_)) ** 2)
 
         return coef_, mse_

From 78946d0fd22b03732dd28396b4b676e4d5bc219b Mon Sep 17 00:00:00 2001
From: Antoine Baker
Date: Thu, 16 Jan 2025 17:49:10 +0100
Subject: [PATCH 8/9] rename mse to sse

---
 sklearn/linear_model/_bayes.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py
index 85b494eeeb057..a38b6b46478af 100644
--- a/sklearn/linear_model/_bayes.py
+++ b/sklearn/linear_model/_bayes.py
@@ -299,8 +299,8 @@ def fit(self, X, y, sample_weight=None):
         # Convergence loop of the bayesian ridge regression
         for iter_ in range(self.max_iter):
             # update posterior mean coef_ based on alpha_ and lambda_ and
-            # compute corresponding mse
-            coef_, mse_ = self._update_coef_(
+            # compute corresponding sse (sum of squared errors)
+            coef_, sse_ = self._update_coef_(
                 X, y, n_samples, n_features, XT_y, U, Vh, eigen_vals_, alpha_, lambda_
             )
             if self.compute_score:
@@ -313,14 +313,14 @@ def fit(self, X, y, sample_weight=None):
                     alpha_,
                     lambda_,
                     coef_,
-                    mse_,
+                    sse_,
                 )
                 self.scores_.append(s)
 
             # Update alpha and lambda according to (MacKay, 1992)
             gamma_ = np.sum((alpha_ * eigen_vals_) / (lambda_ + alpha_ * eigen_vals_))
             lambda_ = (gamma_ + 2 * lambda_1) / (np.sum(coef_**2) + 2 * lambda_2)
-            alpha_ = (sw_sum - gamma_ + 2 * alpha_1) / (mse_ + 2 * alpha_2)
+            alpha_ = (sw_sum - gamma_ + 2 * alpha_1) / (sse_ + 2 * alpha_2)
 
             # Check for convergence
             if iter_ != 0 and np.sum(np.abs(coef_old_ - coef_)) < self.tol:
@@ -335,7 +335,7 @@ def fit(self, X, y, sample_weight=None):
         # log marginal likelihood and posterior covariance
         self.alpha_ = alpha_
         self.lambda_ = lambda_
-        self.coef_, mse_ = self._update_coef_(
+        self.coef_, sse_ = self._update_coef_(
             X, y, n_samples, n_features, XT_y, U, Vh, eigen_vals_, alpha_, lambda_
         )
         if self.compute_score:
@@ -348,7 +348,7 @@ def fit(self, X, y, sample_weight=None):
                 alpha_,
                 lambda_,
                 coef_,
-                mse_,
+                sse_,
             )
             self.scores_.append(s)
         self.scores_ = np.array(self.scores_)
@@ -396,7 +396,7 @@ def predict(self, X, return_std=False):
     def _update_coef_(
         self, X, y, n_samples, n_features, XT_y, U, Vh, eigen_vals_, alpha_, lambda_
     ):
-        """Update posterior mean and compute corresponding mse.
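[Note on patches 7-8] The comment introduced in patch 7 (its wording is cleaned up in patch 8 and again in patch 9 below) rests on a simple identity: _rescale_data multiplies X and y by the square root of the weights, so the unweighted sum of squared errors on the rescaled data equals the weighted sum on the original data. A sketch:

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(15, 2))
y = rng.normal(size=15)
w = rng.uniform(0.5, 2.0, size=15)
coef = rng.normal(size=2)

# sqrt-rescaling, as done by sklearn's private _rescale_data helper
Xw = X * np.sqrt(w)[:, None]
yw = y * np.sqrt(w)

# sum((sqrt(w)*y - sqrt(w)*X @ c)**2) == sum(w * (y - X @ c)**2)
sse_rescaled = np.sum((yw - Xw @ coef) ** 2)
sse_weighted = np.sum(w * (y - X @ coef) ** 2)
assert np.isclose(sse_rescaled, sse_weighted)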
+ """Update posterior mean and compute corresponding sse (sum of squared errors). Posterior mean is given by coef_ = scaled_sigma_ * X.T * y where scaled_sigma_ = (lambda_/alpha_ * np.eye(n_features) @@ -413,13 +413,13 @@ def _update_coef_( ) # Note: we do not need to explicit use the weights in this sum because - # y and X were proprocessed by _rescale_data to handle the weights. - mse_ = np.sum((y - np.dot(X, coef_)) ** 2) + # y and X were preprocessed by _rescale_data to handle the weights. + sse_ = np.sum((y - np.dot(X, coef_)) ** 2) - return coef_, mse_ + return coef_, sse_ def _log_marginal_likelihood( - self, n_samples, n_features, sw_sum, eigen_vals, alpha_, lambda_, coef, mse + self, n_samples, n_features, sw_sum, eigen_vals, alpha_, lambda_, coef, sse ): """Log marginal likelihood.""" alpha_1 = self.alpha_1 @@ -442,7 +442,7 @@ def _log_marginal_likelihood( score += 0.5 * ( n_features * log(lambda_) + sw_sum * log(alpha_) - - alpha_ * mse + - alpha_ * sse - lambda_ * np.sum(coef**2) + logdet_sigma - sw_sum * log(2 * np.pi) @@ -704,12 +704,12 @@ def update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_): coef_ = update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_) # Update alpha and lambda - mse_ = np.sum((y - np.dot(X, coef_)) ** 2) + sse_ = np.sum((y - np.dot(X, coef_)) ** 2) gamma_ = 1.0 - lambda_[keep_lambda] * np.diag(sigma_) lambda_[keep_lambda] = (gamma_ + 2.0 * lambda_1) / ( (coef_[keep_lambda]) ** 2 + 2.0 * lambda_2 ) - alpha_ = (n_samples - gamma_.sum() + 2.0 * alpha_1) / (mse_ + 2.0 * alpha_2) + alpha_ = (n_samples - gamma_.sum() + 2.0 * alpha_1) / (sse_ + 2.0 * alpha_2) # Prune the weights with a precision over a threshold keep_lambda = lambda_ < self.threshold_lambda @@ -724,7 +724,7 @@ def update_coeff(X, y, coef_, alpha_, keep_lambda, sigma_): + n_samples * log(alpha_) + np.sum(np.log(lambda_)) ) - s -= 0.5 * (alpha_ * mse_ + (lambda_ * coef_**2).sum()) + s -= 0.5 * (alpha_ * sse_ + (lambda_ * coef_**2).sum()) self.scores_.append(s) # Check for convergence From eb3ccc08682eb84985c5ac83d7ebf3a8071826c3 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Fri, 17 Jan 2025 11:16:48 +0500 Subject: [PATCH 9/9] Update sklearn/linear_model/_bayes.py --- sklearn/linear_model/_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py index a38b6b46478af..27ce01d0e75d5 100644 --- a/sklearn/linear_model/_bayes.py +++ b/sklearn/linear_model/_bayes.py @@ -412,7 +412,7 @@ def _update_coef_( [X.T, U / (eigen_vals_ + lambda_ / alpha_)[None, :], U.T, y] ) - # Note: we do not need to explicit use the weights in this sum because + # Note: we do not need to explicitly use the weights in this sum because # y and X were preprocessed by _rescale_data to handle the weights. sse_ = np.sum((y - np.dot(X, coef_)) ** 2)