[MRG+2] Adding return_std options for models in linear_model/bayes.py by sergeyf · Pull Request #7838 · scikit-learn/scikit-learn
Merged · 32 commits · Dec 1, 2016
Commits (32)
c19e2c9  initial commit for return_std (sergeyf, Nov 7, 2016)
2fad2d5  initial commit for return_std (sergeyf, Nov 7, 2016)
a6c0bf3  adding tests, examples, ARD predict_std (sergeyf, Nov 7, 2016)
f92a860  adding tests, examples, ARD predict_std (sergeyf, Nov 7, 2016)
4bae33d  a smidge more documentation (sergeyf, Nov 7, 2016)
ea9fad4  a smidge more documentation (sergeyf, Nov 7, 2016)
25c457e  Missed a few PEP8 issues (sergeyf, Nov 7, 2016)
b905a23  Changing predict_std to return_std #1 (sergeyf, Nov 7, 2016)
0a3ccd2  Changing predict_std to return_std #2 (sergeyf, Nov 7, 2016)
5634ee2  Changing predict_std to return_std #3 (sergeyf, Nov 7, 2016)
e817de3  Changing predict_std to return_std final (sergeyf, Nov 7, 2016)
806818a  adding better plots via polynomial regression (sergeyf, Nov 8, 2016)
2f0bd32  trying to fix flake error (sergeyf, Nov 8, 2016)
21ba9d5  fix to ARD plotting issue (sergeyf, Nov 8, 2016)
df3038a  fixing some flakes (sergeyf, Nov 8, 2016)
5d9739d  Two blank lines part 1 (sergeyf, Nov 8, 2016)
542de0b  Two blank lines part 2 (sergeyf, Nov 8, 2016)
a552022  More newlines! (sergeyf, Nov 8, 2016)
b9c55df  Even more newlines (sergeyf, Nov 8, 2016)
b9f7319  adding info to the doc string for the two plot files (sergeyf, Nov 10, 2016)
0cd9f5c  Rephrasing "polynomial" for Bayesian Ridge Regression (sergeyf, Nov 16, 2016)
8eaa4c7  Updating "polynomia" for ARD (sergeyf, Nov 16, 2016)
ba1c2c6  Adding more formal references (sergeyf, Nov 16, 2016)
3599b57  Another asked-for improvement to doc string. (sergeyf, Nov 16, 2016)
6a615f1  Fixing flake8 errors (sergeyf, Nov 16, 2016)
0ded8b7  Cleaning up the tests a smidge. (sergeyf, Nov 20, 2016)
1e1392c  A few more flakes (sergeyf, Nov 21, 2016)
f7e31f1  requested fixes from Andy (sergeyf, Nov 24, 2016)
039ae83  Mini bug fix (sergeyf, Nov 24, 2016)
092a569  Final pep8 fix (sergeyf, Nov 30, 2016)
561ef01  pep8 fix round 2 (sergeyf, Nov 30, 2016)
5bb4080  Fix beta_ to alpha_ in the comments (sergeyf, Dec 1, 2016)
36 changes: 34 additions & 2 deletions examples/linear_model/plot_ard.py
@@ -15,6 +15,12 @@

The estimation of the model is done by iteratively maximizing the
marginal log-likelihood of the observations.

We also plot predictions and uncertainties for ARD
for one-dimensional regression using polynomial feature expansion.
Note that the uncertainty starts rising on the right side of the plot:
these test samples lie outside the range of the training samples.
"""
print(__doc__)

@@ -54,8 +60,8 @@
ols.fit(X, y)

###############################################################################
# Plot the true weights, the estimated weights and the histogram of the
# weights
# Plot the true weights, the estimated weights, the histogram of the
# weights, and predictions with standard deviations
plt.figure(figsize=(6, 5))
plt.title("Weights of the model")
plt.plot(clf.coef_, color='darkblue', linestyle='-', linewidth=2,
@@ -81,4 +87,30 @@
plt.plot(clf.scores_, color='navy', linewidth=2)
plt.ylabel("Score")
plt.xlabel("Iterations")


# Plotting some predictions for polynomial regression
def f(x, noise_amount):
y = np.sqrt(x) * np.sin(x)
noise = np.random.normal(0, 1, len(x))
return y + noise_amount * noise


degree = 10
X = np.linspace(0, 10, 100)
y = f(X, noise_amount=1)
clf_poly = ARDRegression(threshold_lambda=1e5)
clf_poly.fit(np.vander(X, degree), y)

X_plot = np.linspace(0, 11, 25)
y_plot = f(X_plot, noise_amount=0)
y_mean, y_std = clf_poly.predict(np.vander(X_plot, degree), return_std=True)
plt.figure(figsize=(6, 5))
plt.errorbar(X_plot, y_mean, y_std, color='navy',
label="Polynomial ARD", linewidth=2)
plt.plot(X_plot, y_plot, color='gold', linewidth=2,
label="Ground Truth")
plt.ylabel("Output y")
plt.xlabel("Feature X")
plt.legend(loc="lower left")
plt.show()
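
For context on the widening error bars in this plot (and in the matching Bayesian Ridge plot below): the standard deviation that return_std computes for a test point x_* follows the standard Bayesian linear regression result, sketched here in the notation of the predict code added to bayes.py further down (sigma_ is the posterior covariance of the weights, alpha_ the estimated noise precision):

\operatorname{std}(y_* \mid x_*) = \sqrt{\,x_*^\top \Sigma\, x_* + 1/\alpha\,}, \qquad \Sigma = \mathtt{sigma\_},\; \alpha = \mathtt{alpha\_}

With the Vandermonde features np.vander(X, degree), the quadratic form x_*^T Sigma x_* grows quickly once X moves past the training range [0, 10], so the error bars widen at the right edge; inside the training range the noise term 1/alpha typically dominates.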
35 changes: 34 additions & 1 deletion examples/linear_model/plot_bayesian_ridge.py
@@ -15,6 +15,12 @@

The estimation of the model is done by iteratively maximizing the
marginal log-likelihood of the observations.

We also plot predictions and uncertainties for Bayesian Ridge Regression
for one-dimensional regression using polynomial feature expansion.
Note that the uncertainty starts rising on the right side of the plot:
these test samples lie outside the range of the training samples.
"""
print(__doc__)

@@ -51,7 +57,8 @@
ols.fit(X, y)

###############################################################################
# Plot true weights, estimated weights and histogram of the weights
# Plot true weights, estimated weights, histogram of the weights, and
# predictions with standard deviations
lw = 2
plt.figure(figsize=(6, 5))
plt.title("Weights of the model")
@@ -77,4 +84,30 @@
plt.plot(clf.scores_, color='navy', linewidth=lw)
plt.ylabel("Score")
plt.xlabel("Iterations")


# Plotting some predictions for polynomial regression
def f(x, noise_amount):
y = np.sqrt(x) * np.sin(x)
noise = np.random.normal(0, 1, len(x))
return y + noise_amount * noise


degree = 10
X = np.linspace(0, 10, 100)
y = f(X, noise_amount=0.1)
clf_poly = BayesianRidge()
clf_poly.fit(np.vander(X, degree), y)

X_plot = np.linspace(0, 11, 25)
y_plot = f(X_plot, noise_amount=0)
y_mean, y_std = clf_poly.predict(np.vander(X_plot, degree), return_std=True)
plt.figure(figsize=(6, 5))
plt.errorbar(X_plot, y_mean, y_std, color='navy',
label="Polynomial Bayesian Ridge Regression", linewidth=lw)
plt.plot(X_plot, y_plot, color='gold', linewidth=lw,
label="Ground Truth")
plt.ylabel("Output y")
plt.xlabel("Feature X")
plt.legend(loc="lower left")
plt.show()
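
Both examples build their polynomial features with np.vander. A quick, self-contained sketch of what that expansion produces (shapes only, nothing specific to this PR):

import numpy as np

X = np.linspace(0, 10, 100)
# With the default increasing=False, np.vander returns the columns
# [X**9, X**8, ..., X, 1]: degree-9 polynomial features plus a bias column.
features = np.vander(X, 10)
print(features.shape)  # (100, 10)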
109 changes: 103 additions & 6 deletions sklearn/linear_model/bayes.py

@@ -91,6 +91,9 @@ class BayesianRidge(LinearModel, RegressorMixin):
lambda_ : float
estimated precision of the weights.

sigma_ : array, shape = (n_features, n_features)
estimated variance-covariance matrix of the weights

scores_ : float
if computed, value of the objective function (to be maximized)

@@ -109,6 +112,16 @@ class BayesianRidge(LinearModel, RegressorMixin):
Notes
-----
See examples/linear_model/plot_bayesian_ridge.py for an example.

References
----------
D. J. C. MacKay, Bayesian Interpolation, Computation and Neural Systems,
Vol. 4, No. 3, 1992.

R. Salakhutdinov, Lecture notes on Statistical Machine Learning,
http://www.utstat.toronto.edu/~rsalakhu/sta4273/notes/Lecture2.pdf#page=15
Their beta is our self.alpha_
Their alpha is our self.lambda_
"""

def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6,
@@ -142,8 +155,10 @@ def fit(self, X, y):
self : returns an instance of self.
"""
X, y = check_X_y(X, y, dtype=np.float64, y_numeric=True)
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
X, y, X_offset_, y_offset_, X_scale_ = self._preprocess_data(
X, y, self.fit_intercept, self.normalize, self.copy_X)
self.X_offset_ = X_offset_
self.X_scale_ = X_scale_
n_samples, n_features = X.shape

# Initialization of the values of the parameters
@@ -171,7 +186,8 @@ def fit(self, X, y):
# coef_ = sigma_^-1 * XT * y
if n_samples > n_features:
coef_ = np.dot(Vh.T,
Vh / (eigen_vals_ + lambda_ / alpha_)[:, None])
Vh / (eigen_vals_ +
lambda_ / alpha_)[:, np.newaxis])
coef_ = np.dot(coef_, XT_y)
if self.compute_score:
logdet_sigma_ = - np.sum(
@@ -216,10 +232,45 @@ def fit(self, X, y):
self.alpha_ = alpha_
self.lambda_ = lambda_
self.coef_ = coef_
sigma_ = np.dot(Vh.T,
Vh / (eigen_vals_ + lambda_ / alpha_)[:, np.newaxis])
self.sigma_ = (1. / alpha_) * sigma_

self._set_intercept(X_offset, y_offset, X_scale)
self._set_intercept(X_offset_, y_offset_, X_scale_)
return self

def predict(self, X, return_std=False):
"""Predict using the linear model.

In addition to the mean of the predictive distribution, its
standard deviation can also be returned.

Parameters
----------
X : {array-like, sparse matrix}, shape = (n_samples, n_features)
Samples.
return_std : boolean, optional
Whether to return the standard deviation of posterior prediction.

Returns
-------
y_mean : array, shape = (n_samples,)
Mean of predictive distribution of query points.

y_std : array, shape = (n_samples,)
Standard deviation of predictive distribution of query points.
"""
y_mean = self._decision_function(X)
if return_std is False:
return y_mean
else:
if self.normalize:
X = (X - self.X_offset_) / self.X_scale_
sigmas_squared_data = (np.dot(X, self.sigma_) * X).sum(axis=1)
y_std = np.sqrt(sigmas_squared_data + (1. / self.alpha_))
return y_mean, y_std
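
For illustration, a minimal usage sketch of the new return_std path (variable names here are made up; it assumes the defaults fit_intercept=True and normalize=False, so the rescaling branch above is not taken):

import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.RandomState(0)
X = rng.rand(50, 5)
y = np.dot(X, [1., 0., 1., -1., 0.]) + 0.1 * rng.randn(50)
X_new = rng.rand(10, 5)

clf = BayesianRidge()
clf.fit(X, y)
y_mean, y_std = clf.predict(X_new, return_std=True)

# With normalize=False the branch above reduces to the usual Bayesian
# linear regression predictive std: sqrt(x^T sigma_ x + 1 / alpha_).
manual_std = np.sqrt((np.dot(X_new, clf.sigma_) * X_new).sum(axis=1) +
                     1. / clf.alpha_)
# manual_std and y_std should agree element-wise.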


###############################################################################
# ARD (Automatic Relevance Determination) regression
@@ -323,6 +374,19 @@ class ARDRegression(LinearModel, RegressorMixin):
Notes
--------
See examples/linear_model/plot_ard.py for an example.

References
----------
D. J. C. MacKay, Bayesian nonlinear modeling for the prediction
competition, ASHRAE Transactions, 1994.

R. Salakhutdinov, Lecture notes on Statistical Machine Learning,
http://www.utstat.toronto.edu/~rsalakhu/sta4273/notes/Lecture2.pdf#page=15
Their beta is our self.alpha_
Their alpha is our self.lambda_
ARD is a little different from the slides: only dimensions/features for
which self.lambda_ < self.threshold_lambda are kept, and the rest are
discarded.
"""

def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6,
@@ -365,7 +429,7 @@ def fit(self, X, y):
n_samples, n_features = X.shape
coef_ = np.zeros(n_features)

X, y, X_offset, y_offset, X_scale = self._preprocess_data(
X, y, X_offset_, y_offset_, X_scale_ = self._preprocess_data(
X, y, self.fit_intercept, self.normalize, self.copy_X)

# Launch the convergence loop
@@ -417,7 +481,7 @@ def fit(self, X, y):
s = (lambda_1 * np.log(lambda_) - lambda_2 * lambda_).sum()
s += alpha_1 * log(alpha_) - alpha_2 * alpha_
s += 0.5 * (fast_logdet(sigma_) + n_samples * log(alpha_) +
np.sum(np.log(lambda_)))
np.sum(np.log(lambda_)))
s -= 0.5 * (alpha_ * rmse_ + (lambda_ * coef_ ** 2).sum())
self.scores_.append(s)

@@ -432,5 +496,38 @@ def fit(self, X, y):
self.alpha_ = alpha_
self.sigma_ = sigma_
self.lambda_ = lambda_
self._set_intercept(X_offset, y_offset, X_scale)
self._set_intercept(X_offset_, y_offset_, X_scale_)
return self

def predict(self, X, return_std=False):
"""Predict using the linear model.

In addition to the mean of the predictive distribution, its
standard deviation can also be returned.

Parameters
----------
X : {array-like, sparse matrix}, shape = (n_samples, n_features)
Samples.

return_std : boolean, optional
Whether to return the standard deviation of posterior prediction.

Returns
-------
y_mean : array, shape = (n_samples,)
Mean of predictive distribution of query points.

y_std : array, shape = (n_samples,)
Standard deviation of predictive distribution of query points.
"""
y_mean = self._decision_function(X)
if return_std is False:
return y_mean
else:
if self.normalize:
X = (X - self.X_offset_) / self.X_scale_
X = X[:, self.lambda_ < self.threshold_lambda]
sigmas_squared_data = (np.dot(X, self.sigma_) * X).sum(axis=1)
y_std = np.sqrt(sigmas_squared_data + (1. / self.alpha_))
return y_mean, y_std
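
A similar sketch for ARDRegression (again with illustrative data; the point is the pruning mask that predict applies above):

import numpy as np
from sklearn.linear_model import ARDRegression

rng = np.random.RandomState(0)
X = rng.rand(50, 5)
y = np.dot(X, [1., 0., 1., -1., 0.]) + 0.1 * rng.randn(50)

clf = ARDRegression()
clf.fit(X, y)
y_mean, y_std = clf.predict(rng.rand(10, 5), return_std=True)

# Unlike BayesianRidge, ARD only keeps features whose estimated precision
# stayed below threshold_lambda; this is the same mask used in predict:
kept = clf.lambda_ < clf.threshold_lambda
print(kept)  # boolean mask over the 5 input features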
32 changes: 32 additions & 0 deletions sklearn/linear_model/tests/test_bayes.py
@@ -56,3 +56,35 @@ def test_toy_ard_object():
# Check that the model could approximately learn the identity function
test = [[1], [3], [4]]
assert_array_almost_equal(clf.predict(test), [1, 3, 4], 2)


def test_return_std():
# Test return_std option for both Bayesian regressors: with i.i.d. Gaussian
# noise of scale noise_mult, the predicted standard deviation on held-out
# points should be close to noise_mult.
def f(X):
return np.dot(X, w) + b

def f_noise(X, noise_mult):
return f(X) + np.random.randn(X.shape[0])*noise_mult

d = 5
n_train = 50
n_test = 10

w = np.array([1.0, 0.0, 1.0, -1.0, 0.0])
b = 1.0

X = np.random.random((n_train, d))
X_test = np.random.random((n_test, d))

for decimal, noise_mult in enumerate([1, 0.1, 0.01]):
y = f_noise(X, noise_mult)

m1 = BayesianRidge()
m1.fit(X, y)
y_mean1, y_std1 = m1.predict(X_test, return_std=True)
assert_array_almost_equal(y_std1, noise_mult, decimal=decimal)

m2 = ARDRegression()
m2.fit(X, y)
y_mean2, y_std2 = m2.predict(X_test, return_std=True)
assert_array_almost_equal(y_std2, noise_mult, decimal=decimal)