Update doc example on permutation and MDI importance · scikit-learn/scikit-learn@229cc4d · GitHub

Commit 229cc4d

Update doc example on permutation and MDI importance
1 parent 0b48af4 commit 229cc4d

File tree

1 file changed (+51, −7 lines)


examples/inspection/plot_permutation_importance.py

Lines changed: 51 additions & 7 deletions
@@ -16,12 +16,20 @@
 variable, as long as the model has the capacity to use them to overfit.
 
 This example shows how to use Permutation Importances as an alternative that
-can mitigate those limitations.
+can mitigate those limitations. It also introduces a new method from recent
+literature on random forests that allows removing the aforementioned biases
+from MDI while keeping its computational efficiency.
 
 .. rubric:: References
 
 * :doi:`L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
   2001. <10.1023/A:1010933404324>`
+* :doi:`Zhou, Z., & Hooker, G., "Unbiased measurement of feature importance in
+  tree-based methods". ACM Transactions on Knowledge Discovery from Data, 15(2),
+  Article 26, 2020. <10.1145/3429445>`
+* :doi:`Li, X., Wang, Y., Basu, S., Kumbier, K., & Yu, B., "A debiased MDI
+  feature importance measure for random forests". Proceedings of the 33rd Conference on
+  Neural Information Processing Systems (NeurIPS 2019). <10.48550/arXiv.1906.10845>`
 
 """

@@ -87,7 +95,7 @@
 rf = Pipeline(
     [
         ("preprocess", preprocessing),
-        ("classifier", RandomForestClassifier(random_state=42)),
+        ("classifier", RandomForestClassifier(random_state=42, oob_score=True)),
     ]
 )
 rf.fit(X_train, y_train)
@@ -98,9 +106,16 @@
 # Before inspecting the feature importances, it is important to check that
 # the model predictive performance is high enough. Indeed, there would be little
 # interest in inspecting the important features of a non-predictive model.
+#
+# By default, random forests train each tree on a bootstrap sample of the dataset, a
+# procedure known as bagging, leaving aside "out-of-bag" (oob) samples.
+# These samples can be leveraged to compute an accuracy score independently of the
+# training samples, when setting the parameter `oob_score=True`.
+# This score should be close to the test score.
 
 print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}")
 print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}")
+print(f"RF out-of-bag accuracy: {rf[-1].oob_score_:.3f}")
 
 # %%
 # Here, one can observe that the train accuracy is very high (the forest model
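The out-of-bag mechanism described in the added comment can be verified with released scikit-learn alone. Below is a minimal sketch; it uses a synthetic dataset from `make_classification` rather than the Titanic data of the example, and relies only on the public `oob_score` parameter and `oob_score_` attribute.

# Minimal sketch (synthetic data, not the example's Titanic dataset): each tree is
# trained on a bootstrap sample, and the rows left out of that sample ("out-of-bag")
# are used to compute oob_score_ without touching the held-out test set.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2_000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

rf_demo = RandomForestClassifier(oob_score=True, random_state=0).fit(X_train, y_train)

print(f"train accuracy: {rf_demo.score(X_train, y_train):.3f}")
print(f"oob accuracy:   {rf_demo.oob_score_:.3f}")  # estimated from out-of-bag rows only
print(f"test accuracy:  {rf_demo.score(X_test, y_test):.3f}")  # should be close to oob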
@@ -140,17 +155,27 @@
 #
 # The fact that we use training set statistics explains why both the
 # `random_num` and `random_cat` features have a non-null importance.
+#
+# The attribute `ufi_feature_importances_`, available as soon as `oob_score` is set to
+# `True`, uses the out-of-bag samples of each tree to correct these biases.
+# It successfully detects the uninformative features by assigning them a near-zero
+# (here slightly negative) importance value.
+# The prefix `ufi` refers to the name given by the authors to their method. Another
+# method is available through the attribute `mdi_oob_feature_importances_`. See the
+# references for more details on these methods.
 import pandas as pd
 
 feature_names = rf[:-1].get_feature_names_out()
 
-mdi_importances = pd.Series(
-    rf[-1].feature_importances_, index=feature_names
-).sort_values(ascending=True)
+mdi_importances = pd.DataFrame(index=feature_names)
+mdi_importances.loc[:, "unbiased mdi"] = rf[-1].ufi_feature_importances_
+mdi_importances.loc[:, "mdi"] = rf[-1].feature_importances_
+mdi_importances = mdi_importances.sort_values(ascending=True, by="mdi")
 
 # %%
 ax = mdi_importances.plot.barh()
 ax.set_title("Random Forest Feature Importances (MDI)")
+ax.axvline(x=0, color="k", linestyle="--")
 ax.figure.tight_layout()
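Note that `ufi_feature_importances_` and `mdi_oob_feature_importances_` are attributes introduced on the development branch this commit belongs to; they are not part of a released scikit-learn `RandomForestClassifier`. As a rough stand-in using released API only, the sketch below contrasts MDI with permutation importance computed on held-out data, which likewise flags an uninformative feature as unimportant. The synthetic data and feature names are assumptions for illustration, not taken from the example.

# Rough sketch with released API only: compare MDI (training-set, impurity-based)
# against permutation importance on a held-out set for a pure-noise feature.
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
X, y = make_classification(n_samples=2_000, n_features=5, random_state=0)
X = pd.DataFrame(X, columns=[f"num_{i}" for i in range(5)])
X["random_num"] = rng.randn(len(X))  # uninformative, high-cardinality numeric column

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = RandomForestClassifier(random_state=0).fit(X_train, y_train)

mdi = pd.Series(forest.feature_importances_, index=X.columns)
perm = permutation_importance(forest, X_test, y_test, n_repeats=10, random_state=0)
perm = pd.Series(perm.importances_mean, index=X.columns)

# MDI tends to give the noise column a visibly non-zero share, while the held-out
# permutation importance keeps it near zero.
print(pd.DataFrame({"mdi": mdi, "permutation (test)": perm}).round(3))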
# %%
@@ -232,15 +257,34 @@
 )
 
 # %%
+import matplotlib.pyplot as plt
+
 for name, importances in zip(["train", "test"], [train_importances, test_importances]):
     ax = importances.plot.box(vert=False, whis=10)
     ax.set_title(f"Permutation Importances ({name} set)")
     ax.set_xlabel("Decrease in accuracy score")
     ax.axvline(x=0, color="k", linestyle="--")
     ax.figure.tight_layout()
 
+plt.figure()
+umdi_importances = pd.Series(
+    rf[-1].ufi_feature_importances_[sorted_importances_idx],
+    index=feature_names[sorted_importances_idx],
+)
+ax = umdi_importances.plot.barh()
+ax.set_title("Debiased MDI")
+ax.axvline(x=0, color="k", linestyle="--")
+ax.figure.tight_layout()
 # %%
 # Now, we can observe that on both sets, the `random_num` and `random_cat`
-# features have a lower importance compared to the overfitting random forest.
-# However, the conclusions regarding the importance of the other features are
+# features have a lower permutation importance compared to the overfitting random
+# forest. However, the conclusions regarding the importance of the other features are
 # still valid.
+#
+# These accurate permutation importances match the results obtained with oob-based
+# impurity methods on this new random forest.
+#
+# Note that permutation importances are costly, as they require re-scoring the model
+# on permuted data for every repeat of each feature. When working on large datasets
+# with random forests, it may be preferable to use debiased impurity-based feature
+# importance measures.
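To make the cost argument concrete: permutation importance re-scores the already-fitted model roughly `n_features * n_repeats` times, whereas MDI-style importances are aggregated from statistics already stored in the fitted trees. A minimal, self-contained sketch follows; the synthetic data is an assumption and the absolute timings are machine-dependent.

# Minimal sketch of the cost comparison between MDI and permutation importance.
from time import perf_counter

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=5_000, n_features=30, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = RandomForestClassifier(random_state=0).fit(X_train, y_train)

tic = perf_counter()
_ = forest.feature_importances_  # MDI: aggregated from impurity stats stored in the trees
mdi_time = perf_counter() - tic

tic = perf_counter()
_ = permutation_importance(forest, X_test, y_test, n_repeats=10, random_state=0)
perm_time = perf_counter() - tic  # 30 features x 10 repeats model re-scorings

print(f"MDI lookup:             {mdi_time:.4f} s")
print(f"permutation importance: {perm_time:.2f} s")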

0 commit comments