8000 DOC: use notebook-style for plot_ols_3d (#22547) · scikit-learn/scikit-learn@391baef · GitHub
[go: up one dir, main page]

Skip to content

Commit 391baef

Browse files
SamAdamDaylesteve
andauthored
DOC: use notebook-style for plot_ols_3d (#22547)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent 691972a commit 391baef

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

examples/linear_model/plot_ols_3d.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,18 @@
77
Features 1 and 2 of the diabetes-dataset are fitted and
88
plotted below. It illustrates that although feature 2
99
has a strong coefficient on the full model, it does not
10-
give us much regarding `y` when compared to just feature 1
11-
10+
give us much regarding `y` when compared to just feature 1.
1211
"""
1312

1413
# Code source: Gaël Varoquaux
1514
# Modified for documentation by Jaques Grobler
1615
# License: BSD 3 clause
1716

18-
import matplotlib.pyplot as plt
19-
import numpy as np
20-
from mpl_toolkits.mplot3d import Axes3D
17+
# %%
18+
# First we load the diabetes dataset.
2119

22-
from sklearn import datasets, linear_model
20+
from sklearn import datasets
21+
import numpy as np
2322

2423
X, y = datasets.load_diabetes(return_X_y=True)
2524
indices = (0, 1)
@@ -29,16 +28,25 @@
2928
y_train = y[:-20]
3029
y_test = y[-20:]
3130

31+
# %%
32+
# Next we fit a linear regression model.
33+
34+
from sklearn import linear_model
35+
3236
ols = linear_model.LinearRegression()
33-
ols.fit(X_train, y_train)
37+
_ = ols.fit(X_train, y_train)
38+
39+
40+
# %%
41+
# Finally we plot the figure from three different views.
42+
43+
import matplotlib.pyplot as plt
3444

3545

36-
# #############################################################################
37-
# Plot the figure
3846
def plot_figs(fig_num, elev, azim, X_train, clf):
3947
fig = plt.figure(fig_num, figsize=(4, 3))
4048
plt.clf()
41-
ax = Axes3D(fig, elev=elev, azim=azim)
49+
ax = fig.add_subplot(111, projection="3d", elev=elev, azim=azim)
4250

4351
ax.scatter(X_train[:, 0], X_train[:, 1], y_train, c="k", marker="+")
4452
ax.plot_surface(

0 commit comments

Comments
 (0)
0