8000 ENH add vlines_ attribute to PDP Display to hide deciles (#15785) · scikit-learn/scikit-learn@b4757f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit b4757f7

Browse files
authored
ENH add vlines_ attribute to PDP Display to hide deciles (#15785)
1 parent dc0cc6e commit b4757f7

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

doc/whats_new/v0.23.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,13 @@ Changelog
373373
:class:`neural_network.MLPClassifier` by clipping the probabilities.
374374
:pr:`16117` by `Thomas Fan`_.
375375

376+
:mod:`sklearn.inspection`
377+
.........................
378+
379+
- |Enhancement| :class:`inspection.PartialDependenceDisplay` now exposes the
380+
deciles lines as attributes so they can be hidden or customized. :pr:`15785`
381+
by `Nicolas Hug`_
382+
376383
:mod:`sklearn.preprocessing`
377384
............................
378385

sklearn/inspection/_plot/partial_dependence.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -391,21 +391,36 @@ class PartialDependenceDisplay:
391391
axes_ : ndarray of matplotlib Axes
392392
If `ax` is an axes or None, `axes_[i, j]` is the axes on the i-th row
393393
and j-th column. If `ax` is a list of axes, `axes_[i]` is the i-th item
394-
in `ax`. Elements that are None corresponds to a nonexisting axes in
394+
in `ax`. Elements that are None correspond to a nonexisting axes in
395395
that position.
396396
397397
lines_ : ndarray of matplotlib Artists
398-
If `ax` is an axes or None, `line_[i, j]` is the partial dependence
398+
If `ax` is an axes or None, `lines_[i, j]` is the partial dependence
399399
curve on the i-th row and j-th column. If `ax` is a list of axes,
400400
`lines_[i]` is the partial dependence curve corresponding to the i-th
401-
item in `ax`. Elements that are None corresponds to a nonexisting axes
401+
item in `ax`. Elements that are None correspond to a nonexisting axes
402402
or an axes that does not include a line plot.
403403
404+
deciles_vlines_ : ndarray of matplotlib LineCollection
405+
If `ax` is an axes or None, `vlines_[i, j]` is the line collection
406+
representing the x axis deciles of the i-th row and j-th column. If
407+
`ax` is a list of axes, `vlines_[i]` corresponds to the i-th item in
408+
`ax`. Elements that are None correspond to a nonexisting axes or an
409+
axes that does not include a PDP plot.
410+
.. versionadded:: 0.23
411+
deciles_hlines_ : ndarray of matplotlib LineCollection
412+
If `ax` is an axes or None, `vlines_[i, j]` is the line collection
413+
representing the y axis deciles of the i-th row and j-th column. If
414+
`ax` is a list of axes, `vlines_[i]` corresponds to the i-th item in
415+
`ax`. Elements that are None correspond to a nonexisting axes or an
416+
axes that does not include a 2-way plot.
417+
.. versionadded:: 0.23
418+
404419
contours_ : ndarray of matplotlib Artists
405420
If `ax` is an axes or None, `contours_[i, j]` is the partial dependence
406421
plot on the i-th row and j-th column. If `ax` is a list of axes,
407422
`contours_[i]` is the partial dependence plot corresponding to the i-th
408-
item in `ax`. Elements that are None corresponds to a nonexisting axes
423+
item in `ax`. Elements that are None correspond to a nonexisting axes
409424
or an axes that does not include a contour plot.
410425
411426
figure_ : matplotlib Figure
@@ -490,8 +505,6 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
490505
n_rows = int(np.ceil(n_features / float(n_cols)))
491506

492507
self.axes_ = np.empty((n_rows, n_cols), dtype=np.object)
493-
self.lines_ = np.empty((n_rows, n_cols), dtype=np.object)
494-
self.contours_ = np.empty((n_rows, n_cols), dtype=np.object)
495508

496509
axes_ravel = self.axes_.ravel()
497510

@@ -514,14 +527,20 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
514527
self.bounding_ax_ = None
515528
self.figure_ = ax.ravel()[0].figure
516529
self.axes_ = ax
517-
self.lines_ = np.empty_like(ax, dtype=np.object)
518-
self.contours_ = np.empty_like(ax, dtype=np.object)
519530

520531
# create contour levels for two-way plots
521532
if 2 in self.pdp_lim:
522533
Z_level = np.linspace(*self.pdp_lim[2], num=8)
534+
535+
self.lines_ = np.empty_like(self.axes_, dtype=np.object)
536+
self.contours_ = np.empty_like(self.axes_, dtype=np.object)
537+
self.deciles_vlines_ = np.empty_like(self.axes_, dtype=np.object)
538+
self.deciles_hlines_ = np.empty_like(self.axes_, dtype=np.object)
539+
# Create 1d views of these 2d arrays for easy indexing
523540
lines_ravel = self.lines_.ravel(order='C')
524541
contours_ravel = self.contours_.ravel(order='C')
542+
vlines_ravel = self.deciles_vlines_.ravel(order='C')
543+
hlines_ravel = self.deciles_hlines_.ravel(order='C')
525544

526545
for i, axi, fx, (avg_preds, values) in zip(count(),
527546
self.axes_.ravel(),
@@ -547,8 +566,8 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
547566
trans = transforms.blended_transform_factory(axi.transData,
548567
axi.transAxes)
549568
ylim = axi.get_ylim()
550-
axi.vlines(self.deciles[fx[0]], 0, 0.05, transform=trans,
551-
color='k')
569+
vlines_ravel[i] = axi.vlines(self.deciles[fx[0]], 0, 0.05,
570+
transform=trans, color='k')
552571
axi.set_ylim(ylim)
553572

554573
# Set xlabel if it is not already set
@@ -566,8 +585,8 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
566585
trans = transforms.blended_transform_factory(axi.transAxes,
567586
axi.transData)
568587
xlim = axi.get_xlim()
569-
axi.hlines(self.deciles[fx[1]], 0, 0.05, transform=trans,
570-
color='k')
588+
hlines_ravel[i] = axi.hlines(self.deciles[fx[1]], 0, 0.05,
589+
transform=trans, color='k')
571590
# hline erases xlim
572591
axi.set_ylabel(self.feature_names[fx[1]])
573592
axi.set_xlim(xlim)

sklearn/inspection/_plot/tests/test_plot_partial_dependence.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,20 @@ def test_plot_partial_dependence(grid_resolution, pyplot, clf_boston, boston):
5151
assert disp.axes_.shape == (1, 3)
5252
assert disp.lines_.shape == (1, 3)
5353
assert disp.contours_.shape == (1, 3)
54+
assert disp.deciles_vlines_.shape == (1, 3)
55+
assert disp.deciles_hlines_. F438 shape == (1, 3)
5456

5557
assert disp.lines_[0, 2] is None
5658
assert disp.contours_[0, 0] is None
5759
assert disp.contours_[0, 1] is None
5860

61+
# deciles lines: always show on xaxis, only show on yaxis if 2-way PDP
62+
for i in range(3):
63+
assert disp.deciles_vlines_[0, i] is not None
64+
assert disp.deciles_hlines_[0, 0] is None
65+
assert disp.deciles_hlines_[0, 1] is None
66+
assert disp.deciles_hlines_[0, 2] is not None
67+
5968
assert disp.features == [(0, ), (1, ), (0, 1)]
6069
assert np.all(disp.feature_names == feature_names)
6170
assert len(disp.deciles) == 2
@@ -132,9 +141,15 @@ def test_plot_partial_dependence_str_features(pyplot, clf_boston, boston,
132141
assert disp.axes_.shape == (2, 1)
133142
assert disp.lines_.shape == (2, 1)
134143
assert disp.contours_.shape == (2, 1)
144+
assert disp.deciles_vlines_.shape == (2, 1)
145+
assert disp.deciles_hlines_.shape == (2, 1)
135146

136147
assert disp.lines_[0, 0] is None
148+
assert disp.deciles_vlines_[0, 0] is not None
149+
assert disp.deciles_hlines_[0, 0] is not None
137150
assert disp.contours_[1, 0] is None
151+
assert disp.deciles_hlines_[1, 0] is None
152+
assert disp.deciles_vlines_[1, 0] is not None
138153

139154
# line
140155
ax = disp.axes_[1, 0]
@@ -309,6 +324,8 @@ def test_plot_partial_dependence_multiclass(pyplot):
309324
assert disp_target_0.axes_.shape == (1, 2)
310325
assert disp_target_0.lines_.shape == (1, 2)
311326
assert disp_target_0.contours_.shape == (1, 2)
327+
assert disp_target_0.deciles_vlines_.shape == (1, 2)
328+
assert disp_target_0.deciles_hlines_.shape == (1, 2)
312329
assert all(c is None for c in disp_target_0.contours_.flat)
313330
assert disp_target_0.target_idx == 0
314331

@@ -323,6 +340,8 @@ def test_plot_partial_dependence_multiclass(pyplot):
323340
assert disp_symbol.axes_.shape == (1, 2)
324341
assert disp_symbol.lines_.shape == (1, 2)
325342
assert disp_symbol.contours_.shape == (1, 2)
343+
assert disp_symbol.deciles_vlines_.shape == (1, 2)
344+
assert disp_symbol.deciles_hlines_.shape == (1, 2)
326345
assert all(c is None for c in disp_symbol.contours_.flat)
327346
assert disp_symbol.target_idx == 0
328347

0 commit comments

Comments
 (0)
0