|
39 | 39 | import matplotlib.pyplot as plt
|
40 | 40 | from sklearn import svm, datasets
|
41 | 41 |
|
| 42 | + |
| 43 | +def make_meshgrid(x, y, h=.02): |
| 44 | + """Create a mesh of points to plot in |
| 45 | +
|
| 46 | + Parameters |
| 47 | + ---------- |
| 48 | + x: data to base x-axis meshgrid on |
| 49 | + y: data to base y-axis meshgrid on |
| 50 | + h: stepsize for meshgrid, optional |
| 51 | +
|
| 52 | + Returns |
| 53 | + ------- |
| 54 | + xx, yy : ndarray |
| 55 | + """ |
| 56 | + x_min, x_max = x.min() - 1, x.max() + 1 |
| 57 | + y_min, y_max = y.min() - 1, y.max() + 1 |
| 58 | + xx, yy = np.meshgrid(np.arange(x_min, x_max, h), |
| 59 | + np.arange(y_min, y_max, h)) |
| 60 | + return xx, yy |
| 61 | + |
| 62 | + |
| 63 | +def plot_contours(ax, clf, xx, yy, **params): |
| 64 | + """Plot the decision boundaries for a classifier. |
| 65 | +
|
| 66 | + Parameters |
| 67 | + ---------- |
| 68 | + ax: matplotlib axes object |
| 69 | + clf: a classifier |
| 70 | + xx: meshgrid ndarray |
| 71 | + yy: meshgrid ndarray |
| 72 | + params: dictionary of params to pass to contourf, optional |
| 73 | + """ |
| 74 | + Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) |
| 75 | + Z = Z.reshape(xx.shape) |
| 76 | + out = ax.contourf(xx, yy, Z, **params) |
| 77 | + return out |
| 78 | + |
| 79 | + |
42 | 80 | # import some data to play with
|
43 | 81 | iris = datasets.load_iris()
|
44 |
| -X = iris.data[:, :2] # we only take the first two features. We could |
45 |
| - # avoid this ugly slicing by using a two-dim dataset |
| 82 | +# Take the first two features. We could avoid this by using a two-dim dataset |
| 83 | +X = iris.data[:, :2] |
46 | 84 | y = iris.target
|
47 | 85 |
|
48 |
| -h = .02 # step size in the mesh |
49 |
| - |
50 | 86 | # we create an instance of SVM and fit out data. We do not scale our
|
51 | 87 | # data since we want to plot the support vectors
|
52 | 88 | C = 1.0 # SVM regularization parameter
|
53 |
| -svc = svm.SVC(kernel='linear', C=C).fit(X, y) |
54 |
| -rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(X, y) |
55 |
| -poly_svc = svm.SVC(kernel='poly', degree=3, C=C).fit(X, y) |
56 |
| -lin_svc = svm.LinearSVC(C=C).fit(X, y) |
57 |
| - |
58 |
| -# create a mesh to plot in |
59 |
| -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 |
60 |
| -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 |
61 |
| -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), |
62 |
| - np.arange(y_min, y_max, h)) |
| 89 | +models = (svm.SVC(kernel='linear', C=C), |
| 90 | + svm.LinearSVC(C=C), |
| 91 | + svm.SVC(kernel='rbf', gamma=0.7, C=C), |
| 92 | + svm.SVC(kernel='poly', degree=3, C=C)) |
| 93 | +models = (clf.fit(X, y) for clf in models) |
63 | 94 |
|
64 | 95 | # title for the plots
|
65 |
| -titles = ['SVC with linear kernel', |
| 96 | +titles = ('SVC with linear kernel', |
66 | 97 | 'LinearSVC (linear kernel)',
|
67 | 98 | 'SVC with RBF kernel',
|
68 |
| - 'SVC with polynomial (degree 3) kernel'] |
69 |
| - |
70 |
| - |
71 |
| -for i, clf in enumerate((svc, lin_svc, rbf_svc, poly_svc)): |
72 |
| - # Plot the decision boundary. For that, we will assign a color to each |
73 |
| - # point in the mesh [x_min, x_max]x[y_min, y_max]. |
74 |
| - plt.subplot(2, 2, i + 1) |
75 |
| - plt.subplots_adjust(wspace=0.4, hspace=0.4) |
76 |
| - |
77 |
| - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) |
78 |
| - |
79 |
| - # Put the result into a color plot |
80 |
| - Z = Z.reshape(xx.shape) |
81 |
| - plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8) |
82 |
| - |
83 |
| - # Plot also the training points |
84 |
| - plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm) |
85 |
| - plt.xlabel('Sepal length') |
86 |
| - plt.ylabel('Sepal width') |
87 |
| - plt.xlim(xx.min(), xx.max()) |
88 |
| - plt.ylim(yy.min(), yy.max()) |
89 |
| - plt.xticks(()) |
90 |
| - plt.yticks(()) |
91 |
| - plt.title(titles[i]) |
| 99 | + 'SVC with polynomial (degree 3) kernel') |
| 100 | + |
| 101 | +# Set-up 2x2 grid for plotting. |
| 102 | +fig, sub = plt.subplots(2, 2) |
| 103 | +plt.subplots_adjust(wspace=0.4, hspace=0.4) |
| 104 | + |
| 105 | +X0, X1 = X[:, 0], X[:, 1] |
| 106 | +xx, yy = make_meshgrid(X0, X1) |
| 107 | + |
| 108 | +for clf, title, ax in zip(models, titles, sub.flatten()): |
| 109 | + plot_contours(ax, clf, xx, yy, |
| 110 | + cmap=plt.cm.coolwarm, alpha=0.8) |
| 111 | + ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k') |
| 112 | + ax.set_xlim(xx.min(), xx.max()) |
| 113 | + ax.set_ylim(yy.min(), yy.max()) |
| 114 | + ax.set_xlabel('Sepal length') |
| 115 | + ax.set_ylabel('Sepal width') |
| 116 | + ax.set_xticks(()) |
| 117 | + ax.set_yticks(()) |
| 118 | + ax.set_title(title) |
92 | 119 |
|
93 | 120 | plt.show()
|
0 commit comments