8000 [MRG+1] Refactoring plot_iris svm example. (#8279) · Sundrique/scikit-learn@78f43bc · GitHub
[go: up one dir, main page]

Skip to content

Commit 78f43bc

Browse files
lemonlaugSundrique
authored andcommitted
[MRG+1] Refactoring plot_iris svm example. (scikit-learn#8279)
1 parent b6c948c commit 78f43bc

File tree

1 file changed

+66
-39
lines changed

1 file changed

+66
-39
lines changed

examples/svm/plot_iris.py

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -39,55 +39,82 @@
3939
import matplotlib.pyplot as plt
4040
from sklearn import svm, datasets
4141

42+
43+
def make_meshgrid(x, y, h=.02):
44+
"""Create a mesh of points to plot in
45+
46+
Parameters
47+
----------
48+
x: data to base x-axis meshgrid on
49+
y: data to base y-axis meshgrid on
50+
h: stepsize for meshgrid, optional
51+
52+
Returns
53+
-------
54+
xx, yy : ndarray
55+
"""
56+
x_min, x_max = x.min() - 1, x.max() + 1
57+
y_min, y_max = y.min() - 1, y.max() + 1
58+
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
59+
np.arange(y_min, y_max, h))
60+
return xx, yy
61+
62+
63+
def plot_contours(ax, clf, xx, yy, **params):
64+
"""Plot the decision boundaries for a classifier.
65+
66+
Parameters
67+
----------
68+
ax: matplotlib axes object
69+
clf: a classifier
70+
xx: meshgrid ndarray
71+
yy: meshgrid ndarray
72+
params: dictionary of params to pass to contourf, optional
73+
"""
74+
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
75+
Z = Z.reshape(xx.shape)
76+
out = ax.contourf(xx, yy, Z, **params)
77+
return out
78+
79+
4280
# import some data to play with
4381
iris = datasets.load_iris()
44-
X = iris.data[:, :2] # we only take the first two features. We could
45-
# avoid this ugly slicing by using a two-dim dataset
82+
# Take the first two features. We could avoid this by using a two-dim dataset
83+
X = iris.data[:, :2]
4684
y = iris.target
4785

48-
h = .02 # step size in the mesh
49-
5086
# we create an instance of SVM and fit out data. We do not scale our
5187
# data since we want to plot the support vectors
5288
C = 1.0 # SVM regularization parameter
53-
svc = svm.SVC(kernel='linear', C=C).fit(X, y)
54-
rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(X, y)
55-
poly_svc = svm.SVC(kernel='poly', degree=3, C=C).fit(X, y)
56-
lin_svc = svm.LinearSVC(C=C).fit(X, y)
57-
58-
# create a mesh to plot in
59-
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
60-
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
61-
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
62-
np.arange(y_min, y_max, h))
89+
models = (svm.SVC(kernel='linear', C=C),
90+
svm.LinearSVC(C=C),
91+
svm.SVC(kernel='rbf', gamma=0.7, C=C),
92+
svm.SVC(kernel='poly', degree=3, C=C))
93+
models = (clf.fit(X, y) for clf in models)
6394

6495
# title for the plots
65-
titles = ['SVC with linear kernel',
96+
titles = ('SVC with linear kernel',
6697
'LinearSVC (linear kernel)',
6798
'SVC with RBF kernel',
68-
'SVC with polynomial (degree 3) kernel']
69-
70-
71-
for i, clf in enumerate((svc, lin_svc, rbf_svc, poly_svc)):
72-
# Plot the decision boundary. For that, we will assign a color to each
73-
# point in the mesh [x_min, x_max]x[y_min, y_max].
74-
plt.subplot(2, 2, i + 1)
75-
plt.subplots_adjust(wspace=0.4, hspace=0.4)
76-
77-
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
78-
79-
# Put the result into a color plot
80-
Z = Z.reshape(xx.shape)
81-
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
82-
83-
# Plot also the training points
84-
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm)
85-
plt.xlabel('Sepal length')
86-
plt.ylabel('Sepal width')
87-
plt.xlim(xx.min(), xx.max())
88-
plt.ylim(yy.min(), yy.max())
89-
plt.xticks(())
90-
plt.yticks(())
91-
plt.title(titles[i])
99+
'SVC with polynomial (degree 3) kernel')
100+
101+
# Set-up 2x2 grid for plotting.
102+
fig, sub = plt.subplots(2, 2)
103+
plt.subplots_adjust(wspace=0.4, hspace=0.4)
104+
105+
X0, X1 = X[:, 0], X[:, 1]
106+
xx, yy = make_meshgrid(X0, X1)
107+
108+
for clf, title, ax in zip(models, titles, sub.flatten()):
109+
plot_contours(ax, clf, xx, yy,
110+
cmap=plt.cm.coolwarm, alpha=0.8)
111+
ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
112+
ax.set_xlim(xx.min(), xx.max())
113+
ax.set_ylim(yy.min(), yy.max())
114+
ax.set_xlabel('Sepal length')
115+
ax.set_ylabel('Sepal width')
116+
ax.set_xticks(())
117+
ax.set_yticks(())
118+
ax.set_title(title)
92119

93120
plt.show()

0 commit comments

Comments
 (0)
0