|
3 | 3 | Recursive feature elimination
|
4 | 4 | =============================
|
5 | 5 |
|
6 |
| -A recursive feature elimination example showing the relevance of pixels in |
7 |
| -a digit classification task. |
| 6 | +This example demonstrates how Recursive Feature Elimination |
| 7 | +(:class:`~sklearn.feature_selection.RFE`) can be used to determine the |
| 8 | +importance of individual pixels for classifying handwritten digits. |
| 9 | +:class:`~sklearn.feature_selection.RFE` recursively removes the least |
| 10 | +important features and assigns each feature a rank, where higher |
| 11 | +`ranking_` values denote lower importance. The ranking is visualized using both |
| 12 | +shades of blue and pixel annotations for clarity. As expected, pixels positioned |
| 13 | +at the center of the image tend to be more predictive than those near the edges. |
8 | 14 |
|
9 | 15 | .. note::
|
10 | 16 |
|
|
16 | 22 |
|
17 | 23 | from sklearn.datasets import load_digits
|
18 | 24 | from sklearn.feature_selection import RFE
|
19 |
| -from sklearn.svm import SVC |
| 25 | +from sklearn.linear_model import LogisticRegression |
| 26 | +from sklearn.pipeline import Pipeline |
| 27 | +from sklearn.preprocessing import MinMaxScaler |
20 | 28 |
|
21 | 29 | # Load the digits dataset
|
22 | 30 | digits = load_digits()
|
23 | 31 | X = digits.images.reshape((len(digits.images), -1))
|
24 | 32 | y = digits.target
|
25 | 33 |
|
26 |
| -# Create the RFE object and rank each pixel |
27 |
| -svc = SVC(kernel="linear", C=1) |
28 |
| -rfe = RFE(estimator=svc, n_features_to_select=1, step=1) |
29 |
| -rfe.fit(X, y) |
30 |
| -ranking = rfe.ranking_.reshape(digits.images[0].shape) |
| 34 | +pipe = Pipeline( |
| 35 | + [ |
| 36 | + ("scaler", MinMaxScaler()), |
| 37 | + ("rfe", RFE(estimator=LogisticRegression(), n_features_to_select=1, step=1)), |
| 38 | + ] |
| 39 | +) |
| 40 | + |
| 41 | +pipe.fit(X, y) |
| 42 | +ranking = pipe.named_steps["rfe"].ranking_.reshape(digits.images[0].shape) |
31 | 43 |
|
32 | 44 | # Plot pixel ranking
|
33 | 45 | plt.matshow(ranking, cmap=plt.cm.Blues)
|
| 46 | + |
| 47 | +# Annotate each pixel with its RFE rank (1 = most important) |
| 48 | +for i in range(ranking.shape[0]): |
| 49 | + for j in range(ranking.shape[1]): |
| 50 | + plt.text(j, i, str(ranking[i, j]), ha="center", va="center", color="black") |
| 51 | + |
34 | 52 | plt.colorbar()
|
35 |
| -plt.title("Ranking of pixels with RFE") |
| 53 | +plt.title("Ranking of pixels with RFE\n(Logistic Regression)") |
36 | 54 | plt.show()
|
0 commit comments