10000 DOC Enhanced example visualization to RFE (#28862) · charlesjhill/scikit-learn@79e14a7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 79e14a7

Browse files
authored
DOC Enhanced example visualization to RFE (scikit-learn#28862)
1 parent 70ca21f commit 79e14a7

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

examples/feature_selection/plot_rfe_digits.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
Recursive feature elimination
44
=============================
55
6-
A recursive feature elimination example showing the relevance of pixels in
7-
a digit classification task.
6+
This example demonstrates how Recursive Feature Elimination
7+
(:class:`~sklearn.feature_selection.RFE`) can be used to determine the
8+
importance of individual pixels for classifying handwritten digits.
9+
:class:`~sklearn.feature_selection.RFE` recursively removes the least
10+
significant features, assigning ranks based on their importance, where higher
11+
`ranking_` values denote lower importance. The ranking is visualized using both
12+
shades of blue and pixel annotations for clarity. As expected, pixels positioned
13+
at the center of the image tend to be more predictive than those near the edges.
814
915
.. note::
1016
@@ -16,21 +22,33 @@
1622

1723
from sklearn.datasets import load_digits
1824
from sklearn.feature_selection import RFE
19-
from sklearn.svm import SVC
25+
from sklearn.linear_model import LogisticRegression
26+
from sklearn.pipeline import Pipeline
27+
from sklearn.preprocessing import MinMaxScaler
2028

2129
# Load the digits dataset
2230
digits = load_digits()
2331
X = digits.images.reshape((len(digits.images), -1))
2432
y = digits.target
2533

26-
# Create the RFE object and rank each pixel
27-
svc = SVC(kernel="linear", C=1)
28-
rfe = RFE(estimator=svc, n_features_to_select=1, step=1)
29-
rfe.fit(X, y)
30-
ranking = rfe.ranking_.reshape(digits.images[0].shape)
34+
pipe = Pipeline(
35+
[
36+
("scaler", MinMaxScaler()),
37+
("rfe", RFE(estimator=LogisticRegression(), n_features_to_select=1, step=1)),
38+
]
39+
)
40+
41+
pipe.fit(X, y)
42+
ranking = pipe.named_steps["rfe"].ranking_.reshape(digits.images[0].shape)
3143

3244
# Plot pixel ranking
3345
plt.matshow(ranking, cmap=plt.cm.Blues)
46+
47+
# Add annotations for pixel numbers
48+
for i in range(ranking.shape[0]):
49+
for j in range(ranking.shape[1]):
50+
plt.text(j, i, str(ranking[i, j]), ha="center", va="center", color="black")
51+
3452
plt.colorbar()
35-
plt.title("Ranking of pixels with RFE")
53+
plt.title("Ranking of pixels with RFE\n(Logistic Regression)")
3654
plt.show()

0 commit comments

Comments
 (0)
0