-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG] Adds Plotting API to Partial Dependence #14646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fa93505
d1dc257
1dac62c
bc45123
9d67f6a
5129289
6f5407e
8bc2433
4aa273c
5263235
045aa8f
75f8ed8
5fddc5f
396f1cb
bd121c3
422992e
ca532d7
7316ef5
9b916cf
ec47bd8
170a1b5
9ad9477
cab552a
0f5dee1
e7965d9
24ee8b0
d339328
f7e946d
480bcef
e9e5d3a
4b50a8b
f77f5d1
612a80e
a8182d8
6c3a6c9
42e0c9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,3 +16,4 @@ Developer's Guide | |
performance | ||
advanced_installation | ||
maintainer | ||
plotting |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
.. _plotting_api: | ||
|
||
================================ | ||
Developing with the Plotting API | ||
================================ | ||
|
||
Scikit-learn defines a simple API for creating visualizations for machine | ||
learning. The key features of this API is to run calculations once and to have | ||
the flexibility to adjust the visualizations after the fact. This section is | ||
intended for developers who wish to develop or maintain plotting tools. For | ||
usage, users should refer to the :ref`User Guide <visualizations>`. | ||
|
||
Plotting API Overview | ||
--------------------- | ||
|
||
This logic is encapsulated into a display object where the computed data is | ||
stored and the plotting is done in a `plot` method. The display object's | ||
`__init__` method contains only the data needed to create the visualization. | ||
The `plot` method takes in parameters that only have to do with visualization, | ||
such as a matplotlib axes. The `plot` method will store the matplotlib artists | ||
as attributes allowing for style adjustments through the display object. A | ||
`plot_*` helper function accepts parameters to do the computation and the | ||
parameters used for plotting. After the helper function creates the display | ||
object with the computed values, it calls the display's plot method. Note that | ||
the `plot` method defines attributes related to matplotlib, such as the line | ||
artist. This allows for customizations after calling the `plot` method. | ||
|
||
For example, the `RocCurveDisplay` defines the following methods and | ||
attributes:: | ||
|
||
class RocCurveDisplay: | ||
def __init__(self, fpr, tpr, roc_auc, estimator_name): | ||
... | ||
self.fpr = fpr | ||
self.tpr = tpr | ||
self.roc_auc = roc_auc | ||
self.estimator_name = estimator_name | ||
|
||
def plot(self, ax=None, name=None, **kwargs): | ||
... | ||
self.line_ = ... | ||
self.ax_ = ax | ||
self.figure_ = ax.figure_ | ||
|
||
def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, | ||
drop_intermediate=True, response_method="auto", | ||
name=None, ax=None, **kwargs): | ||
# do computation | ||
viz = RocCurveDisplay(fpr, tpr, roc_auc, | ||
estimator.__class__.__name__) | ||
return viz.plot(ax=ax, name=name, **kwargs) | ||
|
||
Read more in :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py` | ||
and the :ref:`User Guide <visualizations>`. | ||
|
||
Plotting with Multiple Axes | ||
--------------------------- | ||
|
||
Some of the plotting tools like | ||
:func:`~sklearn.inspection.plot_partial_dependence` and | ||
:class:`~sklearn.inspection.PartialDependenceDisplay` support plottong on | ||
multiple axes. Two different scenarios are supported: | ||
|
||
1. If a list of axes is passed in, `plot` will check if the number of axes is | ||
consistent with the number of axes it expects and then draws on those axes. 2. | ||
If a single axes is passed in, that axes defines a space for multiple axes to | ||
be placed. In this case, we suggest using matplotlib's | ||
`~matplotlib.gridspec.GridSpecFromSubplotSpec` to split up the space:: | ||
|
||
import matplotlib.pyplot as plt | ||
from matplotlib.gridspec import GridSpecFromSubplotSpec | ||
|
||
fig, ax = plt.subplots() | ||
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec()) | ||
|
||
ax_top_left = fig.add_subplot(gs[0, 0]) | ||
ax_top_right = fig.add_subplot(gs[0, 1]) | ||
ax_bottom = fig.add_subplot(gs[1, :]) | ||
|
||
By default, the `ax` keyword in `plot` is `None`. In this case, the single | ||
axes is created and the gridspec api is used to create the regions to plot in. | ||
|
||
See for example, :func:`~sklearn.inspection.plot_partial_dependence` which | ||
plots multiple lines and contours using this API. The axes defining the | ||
bounding box is saved in a `bounding_ax_` attribute. The individual axes | ||
created are stored in an `axes_` ndarray, corresponding to the axes position on | ||
the grid. Positions that are not used are set to `None`. Furthermore, the | ||
matplotlib Artists are stored in `lines_` and `contours_` where the key is the | ||
position on the grid. When a list of axes is passed in, the `axes_`, `lines_`, | ||
and `contours_` is a 1d ndarray corresponding to the list of axes passed in. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
""" | ||
========================================= | ||
Advanced Plotting With Partial Dependence | ||
========================================= | ||
The :func:`~sklearn.inspection.plot_partial_dependence` function returns a | ||
:class:`~sklearn.inspection.PartialDependenceDisplay` object that can be used | ||
for plotting without needing to recalculate the partial dependence. In this | ||
example, we show how to plot partial dependence plots and how to quickly | ||
customize the plot with the visualization API. | ||
|
||
.. note:: | ||
|
||
See also :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py` | ||
|
||
""" | ||
print(__doc__) | ||
|
||
import matplotlib.pyplot as plt | ||
from sklearn.datasets import load_boston | ||
from sklearn.neural_network import MLPRegressor | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.pipeline import make_pipeline | ||
from sklearn.tree import DecisionTreeRegressor | ||
from sklearn.inspection import plot_partial_dependence | ||
|
||
|
||
############################################################################## | ||
# Train models on the boston housing price dataset | ||
# ================================================ | ||
# | ||
# First, we train a decision tree and a multi-layer perceptron on the boston | ||
# housing price dataset. | ||
|
||
boston = load_boston() | ||
X, y = boston.data, boston.target | ||
feature_names = boston.feature_names | ||
|
||
tree = DecisionTreeRegressor() | ||
mlp = make_pipeline(StandardScaler(), | ||
MLPRegressor(hidden_layer_sizes=(100, 100), | ||
tol=1e-2, max_iter=500, random_state=0)) | ||
tree.fit(X, y) | ||
mlp.fit(X, y) | ||
|
||
|
||
############################################################################## | ||
# Plotting partial dependence for two features | ||
# ============================================ | ||
# | ||
# We plot partial dependence curves for features "LSTAT" and "RM" for | ||
# the decision tree. With two features, | ||
# :func:`~sklearn.inspection.plot_partial_dependence` expects to plot two | ||
# curves. Here the plot function place a grid of two plots using the space | ||
# defined by `ax` . | ||
fig, ax = plt.subplots(figsize=(12, 6)) | ||
ax.set_title("Decision Tree") | ||
tree_disp = plot_partial_dependence(tree, X, ["LSTAT", "RM"], | ||
feature_names=feature_names, ax=ax) | ||
|
||
############################################################################## | ||
# The partial depdendence curves can be plotted for the multi-layer perceptron. | ||
# In this case, `line_kw` is passed to | ||
# :func:`~sklearn.inspection.plot_partial_dependence` to change the color of | ||
# the curve. | ||
fig, ax = plt.subplots(figsize=(12, 6)) | ||
ax.set_title("Multi-layer Perceptron") | ||
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT", "RM"], | ||
feature_names=feature_names, ax=ax, | ||
line_kw={"c": "red"}) | ||
|
||
############################################################################## | ||
# Plotting partial dependence of the two models together | ||
# ====================================================== | ||
# | ||
# The `tree_disp` and `mlp_disp` | ||
# :class:`~sklearn.inspection.PartialDependenceDisplay` objects contain all the | ||
# computed information needed to recreate the partial dependence curves. This | ||
# means we can easily create additional plots without needing to recompute the | ||
# curves. | ||
# | ||
# One way to plot the curves is to place them in the same figure, with the | ||
# curves of each model on each row. First, we create a figure with two axes | ||
# within two rows and one column. The two axes are passed to the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the example but it's pretty condensed and it combines multiple functionalities at once so it's hard to understand how things interact independently. For example it's a bit confusing at first how you can end up with a 2 by 2 grid while you only asked for 1 column. I think it's missing a very simple plot with just 1 feature where you combine the curves of the 2 models. Then There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a "Plotting partial dependence for one feature" section at the end. With the current API, there are two ways to plot a single feature: Option 1tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
feature_names=feature_names)
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
feature_names=feature_names,
ax=tree_disp.axes_, line_kw={"c": "red"}) Option 2_, ax = plt.subplots()
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
feature_names=feature_names, ax=[ax])
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
feature_names=feature_names,
ax=[ax], line_kw={"c": "red"}) For the example, I went with option 1. The "nicer" way to do this is: Possible option that is being disallowed._, ax = plt.subplots()
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
feature_names=feature_names, ax=ax)
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
feature_names=feature_names,
ax=ax, line_kw={"c": "red"}) The first call will call |
||
# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` functions of | ||
# `tree_disp` and `mlp_disp`. The given axes will be used by the plotting | ||
# function to draw the partial dependence. The resulting plot places the | ||
# decision tree partial dependence curves in the first row of the | ||
# multi-layer perceptron in the second row. | ||
|
||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10)) | ||
tree_disp.plot(ax=ax1) | ||
ax1.set_title("Decision Tree") | ||
mlp_disp.plot(ax=ax2, line_kw={"c": "red"}) | ||
ax2.set_title("Multi-layer Perceptron") | ||
|
||
############################################################################## | ||
# Another way to compare the curves is to plot them on top of each other. Here, | ||
# we create a figure with one row and two columns. The axes are passed into the | ||
# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` function as a list, | ||
# which will plot the partial dependence curves of each model on the same axes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a sentence saying that the length of axes must be equal to the number of features we plot? |
||
# The length of the axes list must be equal to the number of plots drawn. | ||
|
||
# Sets this image as the thumbnail for sphinx gallery | ||
# sphinx_gallery_thumbnail_number = 4 | ||
glemaitre marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6)) | ||
tree_disp.plot(ax=[ax1, ax2], line_kw={"label": "Decision Tree"}) | ||
mlp_disp.plot(ax=[ax1, ax2], line_kw={"label": "Multi-layer Perceptron", | ||
"c": "red"}) | ||
ax1.legend() | ||
ax2.legend() | ||
|
||
############################################################################## | ||
# `tree_disp.axes_` is a numpy array container the axes used to draw the | ||
# partial dependence plots. This can be passed to `mlp_disp` to have the same | ||
# affect of drawing the plots on top of each other. Furthermore, the | ||
# `mlp_disp.figure_` stores the figure, which allows for resizing the figure | ||
# after calling `plot`. | ||
|
||
tree_disp.plot(line_kw={"label": "Decision Tree"}) | ||
mlp_disp.plot(line_kw={"label": "Multi-layer Perceptron", "c": "red"}, | ||
ax=tree_disp.axes_) | ||
tree_disp.figure_.set_size_inches(10, 6) | ||
tree_disp.axes_[0, 0].legend() | ||
tree_disp.axes_[0, 1].legend() | ||
plt.show() | ||
|
||
|
||
############################################################################## | ||
# Plotting partial dependence for one feature | ||
# =========================================== | ||
# | ||
# Here, we plot the partial dependence curves for a single feature, "LSTAT", on | ||
# the same axes. In this case, `tree_disp.axes_` is passed into the second | ||
# plot function. | ||
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"], | ||
feature_names=feature_names) | ||
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"], | ||
feature_names=feature_names, | ||
ax=tree_disp.axes_, line_kw={"c": "red"}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
"""The :mod:`sklearn.inspection` module includes tools for model inspection.""" | ||
from .partial_dependence import partial_dependence | ||
from .partial_dependence import plot_partial_dependence | ||
from .partial_dependence import PartialDependenceDisplay | ||
from .permutation_importance import permutation_importance | ||
|
||
__all__ = [ | ||
'partial_dependence', | ||
'plot_partial_dependence', | ||
'permutation_importance' | ||
'permutation_importance', | ||
'PartialDependenceDisplay' | ||
] |
Uh oh!
There was an error while loading. Please reload this page.