[MRG] Adds Plotting API to Partial Dependence #14646
Could we use tight layout in the examples?
I think the possibility of providing a single axes and stealing or providing the right number of axes and reusing should be explained in the dev docs where you explained the display API.
It generally looks good.
The :func:`~sklearn.inspection.plot_partial_dependence` function returns a | ||
:class:`~sklearn.inspection.PartialDependenceDisplay` object that can be used | ||
for plotting without needing to recalculate the partial dependence. In this | ||
example we should how to plot partial dependence plots and quickly customize |
example we should how to plot partial dependence plots and quickly customize | |
example, we show how to plot partial dependence plots and how to quickly customize |
:class:`~sklearn.inspection.PartialDependenceDisplay` object that can be used | ||
for plotting without needing to recalculate the partial dependence. In this | ||
example we should how to plot partial dependence plots and quickly customize | ||
tht plot with the Visualization API. |
tht plot with the Visualization API. | |
the plot with the Visualization API. |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, | ||
random_state=0) | ||
|
||
hgbr = make_pipeline(StandardScaler(), HistGradientBoostingRegressor()) |
There is no need to use the StandardScaler. It will also speed-up the partial-dependence computation since we will use the recursive method included in the HistGradientBoostingRegressor which is not available when using a Pipeline.
# curves. | ||
# | ||
# One way to plot the curves is to place them in the same figure, with the | ||
# curves of each model on each row. First, we create a figure two axes with two |
"a figure two axes" -> is it missing a "with"?
For two-way partial dependence plots. | ||
|
||
ax : Matplotlib axes, list of Matplotlib axes or None, (default=None) | ||
By default, |
Missing the rest of the sentence ;)
from sklearn.inspection import plot_partial_dependence | ||
|
||
|
||
boston = load_boston() |
it would be better to name it BOSTON.
Maybe this is time to use a fixture for the data here instead?
# Test partial dependence plot function on multi-output input. | ||
X, y = multioutput_regression_data | ||
clf = LinearRegression() | ||
clf.fit(X, y) |
are you using clf?
Oops, but this is not ready to be approved yet :) wrong button
doc/developers/contributing.rst
Outdated
@@ -1724,4 +1724,39 @@ attributes: | |||
return viz.plot(ax=ax, name=name, **kwargs) | |||
``` |
``` |
doc/developers/contributing.rst
Outdated
the axes on the grid. Furthermore, the matplotlib Artists are stored in | ||
`lines_` and `contours_` where the key is the position on the grid. When a list | ||
of axes is passsed in, the `axes_`, `lines_`, and `contours_` keys is single | ||
int corresponding to the position on the passed in list of axes. |
Can you link to the advanced partial dependence example here?
@@ -0,0 +1,110 @@ | |||
""" | |||
========================================= | |||
Partial Dependence with Visualization API |
Advanced plotting with partial dependence?
Any partial dependence plot uses the visualization API by definition, right?
Love the example btw!
# curves. | ||
# | ||
# One way to plot the curves is to place them in the same figure, with the | ||
# curves of each model on each row. First, we create a figure two axes with two |
# curves of each model on each row. First, we create a figure two axes with two | |
# curves of each model on each row. First, we create a figure with two axes within two |
We might need to think about the data structures for `self.axes_` a bit more. Or at least you should clarify.
# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` functions of | ||
# `hgbr_disp` and `mlp_disp`. The plot funciton will plot the two curves in the | ||
# space by the passed in axes. The resulting plot places the histogram gradient | ||
# boosting partial dependence curves on top of the multi-layer perceptron |
"places the gbrt plots in the first row and the mlp plots in the second row"?
# Another way to compare the curves is to plot them on top of each other. Here | ||
# we create a figure with one row and two columns. The axes are passed into the | ||
# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` function as list, | ||
# which will plot the partial dependence curves of each model on the same axes. |
Maybe add a sentence saying that the length of axes must be equal to the number of features we plot?
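The constraint suggested in this comment could be enforced with a small validation step. The sketch below is only an illustration: `check_axes_match_features` is a hypothetical helper, not a scikit-learn function, and the error message is invented.

```python
def check_axes_match_features(axes, features):
    """Raise if the number of axes differs from the number of features.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    axes = list(axes)
    if len(axes) != len(features):
        raise ValueError(
            "Expected {} axes (one per feature), got {}".format(
                len(features), len(axes)))
    return axes


# Strings stand in for matplotlib Axes so the sketch runs without a display.
axes = check_axes_match_features(["ax1", "ax2"], ["LSTAT", "RM"])
```

Passing one axes for two features would raise a `ValueError` instead of silently skipping a plot.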
---------- | ||
bounding_ax_ : matplotlib Axes or None | ||
If `ax` is an axes or None, the `bounding_ax_` is the axes where the | ||
grid of partial dependence plots are drawn. If `ax` is list of axes, |
grid of partial dependence plots are drawn. If `ax` is list of axes, | |
grid of partial dependence plots are drawn. If `ax` is a list of axes, |
and jth column. If `ax` is a list of axes, `axes_[i]` is the ith item | ||
in `ax`. | ||
|
||
lines_ : matplotlib Artists |
is this a dict? a list? a numpy array?
grid of partial dependence plots are drawn. If `ax` is list of axes, | ||
`bounding_ax_` is None. | ||
|
||
axes_ : dict of matplotlib Axes |
How can a dict have `[i, j]`?
It's a little weird that the shape of the attributes depends on whether we passed a list of axes or not, but I don't see a nice way around that right now.
given, it is treated as a bounding axes and a grid of partial | ||
depdendence plots will be drawn on that top of it. If a list of | ||
axes are passed, the partial dependence plots will be drawn on | ||
those axes. By default, a single bounding axes is created and |
remove single here?
those axes. By default, a single bounding axes is created and | ||
treated as the single axes case. | ||
|
||
n_cols : int, default=3 |
Honestly, I hate the way this is currently done and we should be doing what dabl is doing or https://github.com/matplotlib/grid-strategy is doing but that's for another day lol.
A figure object onto which the plots will be drawn, after the | ||
figure has been cleared. By default, a new one is created. | ||
|
||
.. deprecated:: 0.22 |
I agree. We need it for backward-compatibility though. We could make it private. I'm not sure if that's less weird?
if fig is None: | ||
_, ax = plt.subplots() | ||
else: | ||
ax = fig.add_subplot(111) |
I hate the matlab syntax. I would use either `add_subplot(1, 1, 1)` or `gca`?
# which will plot the partial dependence curves of each model on the same axes. | ||
|
||
# sphinx_gallery_thumbnail_number = 4 | ||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) |
I feel like it would be more natural to just call `hgbr_disp` with the default arguments and use the generated axes to pass to `mlp_disp`.
Right now the API might require us to do a ravel or something? But that's something that we should support. Maybe we should be using `OrderedDict`s?
@@ -23,16 +23,16 @@ class PartialDependenceDisplay: | |||
|
|||
Parameters | |||
---------- | |||
pd_results : list of (ndarray, ndarray) | |||
pd_results : (ndarray, ndarray) list |
list of tuples?
If a user passes in a list or an ndarray of one dimension, |
Results of `sklearn.inspection.partial_dependence` for ``features``. | ||
Each tuple corresponds to a (averaged_predictions, grid). | ||
|
||
features : list of {(int, ), (int, int)} | ||
features {(int, ), (int, int)} list |
Did I miss a discussion about this style?
Great work!
Made a bunch of comments. I only looked at the docstrings and the examples so far.
- I don't think we really need the `_plot` folder, why not just have `partial_dependence_display.py`? If we must have `_plot` then it would make sense to me that `plot_partial_dependence()` is there. On top of that we now have 3 files named `partial_dependence.py` :/
- About the `deciles`: (I'm literally discovering that now, shame on me), would you mind adding a small comment in the docstring of `plot_partial_dependence`, e.g. "The deciles of the feature values will be shown with tick marks on the x-axis"
- With the current API, is it easy to extend it in the future by allowing e.g. 3d plots instead of contours for 2way pdps?
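For readers who, like the reviewer, are meeting the deciles for the first time: they are the nine cut points that split a feature's observed values into ten equal-sized groups. The sketch below computes them with the standard library on made-up data. It is an assumption-laden illustration; scikit-learn's own computation may use a different quantile estimator.

```python
from statistics import quantiles


def feature_deciles(values):
    """Return the nine cut points that split `values` into ten equal parts."""
    # method="inclusive" treats the data as the full population, so the
    # cut points stay within the observed minimum and maximum.
    return quantiles(values, n=10, method="inclusive")


# Toy feature column; the values are made up for illustration.
lstat = list(range(101))
deciles = feature_deciles(lstat)
```

Each value in `deciles` would correspond to one tick mark on the x-axis of the plot for that feature.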
doc/developers/plotting.rst
Outdated
|
||
When a single axes is passed in, that axes defines a space for the multiple | ||
axes to be placed. In this case, matplotlib's | ||
`gridspec.GridSpecFromSubplotSpec` can be used to split up the space:: |
I'm a bit confused. Is this a suggestion of what the internals of `plot_*` could/should be? If so, let's say it ;)
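To make the idea under discussion concrete, here is a minimal sketch of how a single bounding axes could be split into a grid with `GridSpecFromSubplotSpec`. The helper name `split_bounding_axes` and the choice to hide the bounding axes with `set_axis_off` are assumptions made for illustration, not the PR's implementation.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec


def split_bounding_axes(ax, n_rows, n_cols):
    """Create an n_rows x n_cols grid of axes inside the area of `ax`.

    Hypothetical helper for illustration; not scikit-learn's implementation.
    """
    fig = ax.figure
    # Hide the bounding axes: it only reserves space for the inner grid.
    ax.set_axis_off()
    gs = GridSpecFromSubplotSpec(n_rows, n_cols,
                                 subplot_spec=ax.get_subplotspec())
    # Key each inner axes by its (row, column) position on the grid.
    return {(i, j): fig.add_subplot(gs[i, j])
            for i in range(n_rows) for j in range(n_cols)}


fig, bounding_ax = plt.subplots()
inner_axes = split_bounding_axes(bounding_ax, n_rows=2, n_cols=2)
inner_axes[0, 1].plot([0, 1], [0, 1])
```

Because the inner axes are created from the bounding axes' subplot spec, the grid follows the bounding axes wherever the caller placed it in the figure.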
doc/developers/plotting.rst
Outdated
By default, the `ax` keyworld in `plot` is `None`. In this case, the single | ||
axes is created and the gridspec api is used to create the regions to plot in. | ||
|
||
For example, :func:`~sklearn.inspection.plot_partial_dependence` plots multiple |
I'd suggest "See for example [...] which plots multiple..."
doc/developers/plotting.rst
Outdated
axes is created and the gridspec api is used to create the regions to plot in. | ||
|
||
For example, :func:`~sklearn.inspection.plot_partial_dependence` plots multiple | ||
lines and contours using this API. The axes that is passed in or created that |
"The axes that is passed in or created that defines the space"
That sounds a bit odd to me.
doc/developers/plotting.rst
Outdated
the grid. Positions that are not used are set to `None`. Furthermore, the | ||
matplotlib Artists are stored in `lines_` and `contours_` where the key is the | ||
position on the grid. When a list of axes is passsed in, the `axes_`, `lines_`, | ||
and `contours_` keys is single int corresponding to the position on the passed |
is single int -> "are integers" ?
on -> of ??
I find that whole sentence a bit hard to read.
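Since the sentence under review is hard to parse, a small sketch of the two keying conventions it describes may help. `position_keys` is a hypothetical helper, and the row-major fill order is an assumption; only the distinction between `(row, column)` keys and integer keys comes from the text above.

```python
def position_keys(n_plots, n_cols=3, ax_is_list=False):
    """Return the keys used to index one entry per plot.

    Hypothetical helper mirroring the convention described in the docs
    under review; not part of scikit-learn.
    """
    if ax_is_list:
        # A list of axes was passed in: key by position in that list.
        return list(range(n_plots))
    # A bounding axes (or None) was passed in: key by (row, column),
    # assuming the grid is filled row by row.
    return [divmod(k, n_cols) for k in range(n_plots)]


grid_keys = position_keys(4, n_cols=3)
list_keys = position_keys(4, ax_is_list=True)
```

With four plots and three columns, `grid_keys` addresses three positions on the first row and one on the second, while `list_keys` simply counts through the passed-in list.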
doc/developers/plotting.rst
Outdated
and `contours_` keys is single int corresponding to the position on the passed | ||
in list of axes. | ||
|
||
Read more in :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py` |
should this be in the above section?
# Plotting partial dependence of the two models independently | ||
# =========================================================== | ||
# | ||
# Next, we plot a partial dependence curves for features "LSTAT" and "RM" for |
# Next, we plot a partial dependence curves for features "LSTAT" and "RM" for | |
# Next, we plot a partial dependence curves for features "LSTAT" and "RM" |
|
||
############################################################################## | ||
# The partial depdendence curves can be plotted for the multi-layer perceptron. | ||
# In this case `line_kw` was passed to |
# In this case `line_kw` was passed to | |
# In this case `line_kw` is passed to |
############################################################################## | ||
# The partial depdendence curves can be plotted for the multi-layer perceptron. | ||
# In this case `line_kw` was passed to | ||
# `~sklearn.inspection.plot_partial_dependence` to change the color of the |
does not link
# The partial depdendence curves can be plotted for the multi-layer perceptron. | ||
# In this case `line_kw` was passed to | ||
# `~sklearn.inspection.plot_partial_dependence` to change the color of the | ||
# curve and `n_cols` was set to 1 to set the number of columns to 1. |
# curve and `n_cols` was set to 1 to set the number of columns to 1. | |
# curve and `n_cols` is set to 1. |
# | ||
# One way to plot the curves is to place them in the same figure, with the | ||
# curves of each model on each row. First, we create a figure with two axes | ||
# within two rows and one column. The two axes are passed to the |
I like the example but it's pretty condensed and it combines multiple functionalities at once so it's hard to understand how things interact independently.
For example it's a bit confusing at first how you can end up with a 2 by 2 grid while you only asked for 1 column.
I think it's missing a very simple plot with just 1 feature where you combine the curves of the 2 models. Then
I added a "Plotting partial dependence for one feature" section at the end. With the current API, there are two ways to plot a single feature:
Option 1:
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
                                    feature_names=feature_names)
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
                                   feature_names=feature_names,
                                   ax=tree_disp.axes_, line_kw={"c": "red"})
Option 2:
_, ax = plt.subplots()
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
                                    feature_names=feature_names, ax=[ax])
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
                                   feature_names=feature_names,
                                   ax=[ax], line_kw={"c": "red"})
For the example, I went with option 1. The "nicer" way to do this is:
Possible option that is being disallowed:
_, ax = plt.subplots()
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
                                    feature_names=feature_names, ax=ax)
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
                                   feature_names=feature_names,
                                   ax=ax, line_kw={"c": "red"})
The first call will call `ax.set_visible(False)` to denote that the space has been used. The second call will see that the axes is not visible and raise an error. We can technically support this "single feature" and "single axes" case, which I think will add another layer of complexity to the API, i.e. "If `ax` is a single axes and `len(features) == 1`, then we behave differently".
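The visibility check described in this comment can be sketched in a few lines. This is a simplified illustration, not the PR's code: `claim_bounding_axes` and its error message are hypothetical, and `StandInAxes` is a stand-in that exposes only the `get_visible` and `set_visible` methods of a matplotlib Axes so the sketch runs without a display.

```python
class StandInAxes:
    """Minimal stand-in exposing the two matplotlib Axes methods used below."""

    def __init__(self):
        self._visible = True

    def get_visible(self):
        return self._visible

    def set_visible(self, visible):
        self._visible = visible


def claim_bounding_axes(ax):
    """Mark `ax` as used, refusing an axes that was already claimed.

    Hypothetical helper for illustration; the error message is invented.
    """
    if not ax.get_visible():
        raise ValueError("This axes has already been used as a bounding axes")
    ax.set_visible(False)


ax = StandInAxes()
claim_bounding_axes(ax)  # first call succeeds and hides the axes
try:
    claim_bounding_axes(ax)  # second call sees the hidden axes
    second_call_failed = False
except ValueError:
    second_call_failed = True
```

This is why the third snippet in the comment fails on its second call, while options 1 and 2 avoid the check by passing a list of axes.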
A figure object onto which the plots will be drawn, after the | ||
figure has been cleared. By default, a new one is created. | ||
|
||
.. deprecated:: 0.22 |
Should we just not publicly document it then?
doc/developers/plotting.rst
Outdated
|
||
Scikit-learn defines a simple API for creating visualizations for machine | ||
learning. The key features of this API is to run calculations once and to have | ||
the flexibility to adjust the visualizations after the fact. |
To explain the purpose of this doc we can add something like:
"This section is intended for developers who wish to develop or maintain plotting tools. For usage, users should refer to the user guide <link>"
I have trying to keep using this |
I won't argue more but I find that having a module with a single file isn't necessary (esp. since it creates other kinds of inconsistencies), and having 3 different files named |
I have been avoiding 3-d plots. They... add another dimension to plotting. I think current |
A figure object onto which the plots will be drawn, after the | ||
figure has been cleared. By default, a new one is created. | ||
|
||
.. deprecated:: 0.22 |
.. deprecated:: 0.22 | |
.. deprecated:: 0.22 | |
``fig`` will be removed in 0.24. |
Dict with keywords passed to the `matplotlib.pyplot.contourf` | ||
call for two-way partial dependence plots. | ||
|
||
fig : Matplotlib figure object, optional (default=None) |
fig : Matplotlib figure object, optional (default=None) | |
fig : Matplotlib figure object, default=None |
Is there a way to not introduce a parameter which is directly deprecated.
Updated PR to not need to add a `fig` just to deprecate it.
:class:`~sklearn.inspection.PartialDependenceDisplay` object that can be used | ||
for plotting without needing to recalculate the partial dependence. In this | ||
example, we show how to plot partial dependence plots and how to quickly | ||
customize the plot with the Visualization API. |
customize the plot with the Visualization API. | |
customize the plot with the visualization API. |
For two-way partial dependence plots. | ||
|
||
ax : Matplotlib axes or array-like of Matplotlib axes, default=None | ||
- If a single axes is passed in, it is treated as a bounding axes |
- If a single axes is passed in, it is treated as a bounding axes | |
- If a single axis is passed in, it is treated as a bounding axes |
and a grid of partial depedendence plots will be drawn within | ||
these bounds. The `n_cols` parameter controls the number of | ||
columns in the grid. | ||
- If a array-like of axes are passed in, the partial dependence |
- If a array-like of axes are passed in, the partial dependence | |
- If an array-like of axes are passed in, the partial dependence |
columns in the grid. | ||
- If a array-like of axes are passed in, the partial dependence | ||
plots will be drawn directly into these axes. | ||
- If `None`, a figure and a bounding axes is created and treated |
- If `None`, a figure and a bounding axes is created and treated | |
- If `None`, a figure and a bounding axis is created and treated |
- If a array-like of axes are passed in, the partial dependence | ||
plots will be drawn directly into these axes. | ||
- If `None`, a figure and a bounding axes is created and treated | ||
as the single axes case. |
as the single axes case. | |
as the single axis case. |
display = PartialDependenceDisplay(pd_results, features, feature_names, | ||
target_idx, pdp_lim, deciles) | ||
return display.plot(ax=ax, n_cols=n_cols, line_kw=line_kw, | ||
contour_kw=contour_kw, fig=fig) |
Regarding the introduction of fig
LGTM Thanks @thomasjpfan |
Reference Issues/PRs
Fixes #14596
What does this implement/fix? Explain your changes.
- `PartialDependenceDisplay` for `plot_partial_dependence`.
- `fig` option. The new API can take a single axes and place all the plots in that ax.
- `plot_partial_dependence_visualization_api.py` is used to demonstrate how to use `plot_partial_dependence` with this axes based API.
Any other comments?
This PR removes the custom padding between subplots because the padding depends on the size of the figure. When using matplotlib's constrained_layout, the labels of the plots do not overlap. This cannot be used in our examples because we support matplotlib 1.5.1 (which does not have constrained_layout).
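For reference, the constrained_layout option mentioned above is enabled when the figure is created. The snippet below is a minimal sketch that assumes a matplotlib release recent enough to provide the option; the axis labels are placeholders.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# constrained_layout is not available in older releases such as 1.5.1.
fig, axes = plt.subplots(2, 2, constrained_layout=True)
for ax in axes.ravel():
    ax.set_xlabel("feature value")
    ax.set_ylabel("partial dependence")
fig.canvas.draw()  # the layout is resolved at draw time
```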