DOC Fix features selection example (#12748) · qdeffense/scikit-learn@9014a6f · GitHub
[go: up one dir, main page]

Skip to content

Commit 9014a6f

Browse files
pierretallotte authored and glemaitre committed
DOC Fix features selection example (scikit-learn#12748)
1 parent ec2ea1b commit 9014a6f

File tree

1 file changed

+34
-19
lines changed

1 file changed

+34
-19
lines changed

examples/feature_selection/plot_feature_selection.py

Lines changed: 34 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""
2-
===============================
2+
============================
33
Univariate Feature Selection
4-
===============================
4+
============================
55
66
An example showing univariate feature selection.
77
@@ -24,21 +24,29 @@
2424
import numpy as np
2525
import matplotlib.pyplot as plt
2626

27-
from sklearn import datasets, svm
28-
from sklearn.feature_selection import SelectPercentile, f_classif
27+
from sklearn.datasets import load_iris
28+
from sklearn.model_selection import train_test_split
29+
from sklearn.preprocessing import MinMaxScaler
30+
from sklearn.svm import LinearSVC
31+
from sklearn.pipeline import make_pipeline
32+
from sklearn.feature_selection import SelectKBest, f_classif
2933

3034
# #############################################################################
3135
# Import some data to play with
3236

3337
# The iris dataset
34-
iris = datasets.load_iris()
38+
X, y = load_iris(return_X_y=True)
3539

3640
# Some noisy data not correlated
37-
E = np.random.uniform(0, 0.1, size=(len(iris.data), 20))
41+
E = np.random.RandomState(42).uniform(0, 0.1, size=(X.shape[0], 20))
3842

3943
# Add the noisy data to the informative features
40-
X = np.hstack((iris.data, E))
41-
y = iris.target
44+
X = np.hstack((X, E))
45+
46+
# Split dataset to select feature and evaluate the classifier
47+
X_train, X_test, y_train, y_test = train_test_split(
48+
X, y, stratify=y, random_state=0
49+
)
4250

4351
plt.figure(1)
4452
plt.clf()
@@ -47,9 +55,10 @@
4755

4856
# #############################################################################
4957
# Univariate feature selection with F-test for feature scoring
50-
# We use the default selection function: the 10% most significant features
51-
selector = SelectPercentile(f_classif, percentile=10)
52-
selector.fit(X, y)
58+
# We use the default selection function to select the four
59+
# most significant features
60+
selector = SelectKBest(f_classif, k=4)
61+
selector.fit(X_train, y_train)
5362
scores = -np.log10(selector.pvalues_)
5463
scores /= scores.max()
5564
plt.bar(X_indices - .45, scores, width=.2,
@@ -58,20 +67,26 @@
5867

5968
# #############################################################################
6069
# Compare to the weights of an SVM
61-
clf = svm.SVC(kernel='linear')
62-
clf.fit(X, y)
70+
clf = make_pipeline(MinMaxScaler(), LinearSVC())
71+
clf.fit(X_train, y_train)
72+
print('Classification accuracy without selecting features: {:.3f}'
73+
.format(clf.score(X_test, y_test)))
6374

64-
svm_weights = (clf.coef_ ** 2).sum(axis=0)
65-
svm_weights /= svm_weights.max()
75+
svm_weights = np.abs(clf[-1].coef_).sum(axis=0)
76+
svm_weights /= svm_weights.sum()
6677

6778
plt.bar(X_indices - .25, svm_weights, width=.2, label='SVM weight',
6879
color='navy', edgecolor='black')
6980

70-
clf_selected = svm.SVC(kernel='linear')
71-
clf_selected.fit(selector.transform(X), y)
81+
clf_selected = make_pipeline(
82+
SelectKBest(f_classif, k=4), MinMaxScaler(), LinearSVC()
83+
)
84+
clf_selected.fit(X_train, y_train)
85+
print('Classification accuracy after univariate feature selection: {:.3f}'
86+
.format(clf_selected.score(X_test, y_test)))
7287

73-
svm_weights_selected = (clf_selected.coef_ ** 2).sum(axis=0)
74-
svm_weights_selected /= svm_weights_selected.max()
88+
svm_weights_selected = np.abs(clf_selected[-1].coef_).sum(axis=0)
89+
svm_weights_selected /= svm_weights_selected.sum()
7590

7691
plt.bar(X_indices[selector.get_support()] - .05, svm_weights_selected,
7792
width=.2, label='SVM weights after selection', color='c',

0 commit comments

Comments (0)