@@ -391,21 +391,36 @@ class PartialDependenceDisplay:
391
391
axes_ : ndarray of matplotlib Axes
392
392
If `ax` is an axes or None, `axes_[i, j]` is the axes on the i-th row
393
393
and j-th column. If `ax` is a list of axes, `axes_[i]` is the i-th item
394
- in `ax`. Elements that are None corresponds to a nonexisting axes in
394
+ in `ax`. Elements that are None correspond to a nonexisting axes in
395
395
that position.
396
396
397
397
lines_ : ndarray of matplotlib Artists
398
- If `ax` is an axes or None, `line_ [i, j]` is the partial dependence
398
+ If `ax` is an axes or None, `lines_ [i, j]` is the partial dependence
399
399
curve on the i-th row and j-th column. If `ax` is a list of axes,
400
400
`lines_[i]` is the partial dependence curve corresponding to the i-th
401
- item in `ax`. Elements that are None corresponds to a nonexisting axes
401
+ item in `ax`. Elements that are None correspond to a nonexisting axes
402
402
or an axes that does not include a line plot.
403
403
404
+ deciles_vlines_ : ndarray of matplotlib LineCollection
405
+ If `ax` is an axes or None, `vlines_[i, j]` is the line collection
406
+ representing the x axis deciles of the i-th row and j-th column. If
407
+ `ax` is a list of axes, `vlines_[i]` corresponds to the i-th item in
408
+ `ax`. Elements that are None correspond to a nonexisting axes or an
409
+ axes that does not include a PDP plot.
410
+ .. versionadded:: 0.23
411
+ deciles_hlines_ : ndarray of matplotlib LineCollection
412
+ If `ax` is an axes or None, `vlines_[i, j]` is the line collection
413
+ representing the y axis deciles of the i-th row and j-th column. If
414
+ `ax` is a list of axes, `vlines_[i]` corresponds to the i-th item in
415
+ `ax`. Elements that are None correspond to a nonexisting axes or an
416
+ axes that does not include a 2-way plot.
417
+ .. versionadded:: 0.23
418
+
404
419
contours_ : ndarray of matplotlib Artists
405
420
If `ax` is an axes or None, `contours_[i, j]` is the partial dependence
406
421
plot on the i-th row and j-th column. If `ax` is a list of axes,
407
422
`contours_[i]` is the partial dependence plot corresponding to the i-th
408
- item in `ax`. Elements that are None corresponds to a nonexisting axes
423
+ item in `ax`. Elements that are None correspond to a nonexisting axes
409
424
or an axes that does not include a contour plot.
410
425
411
426
figure_ : matplotlib Figure
@@ -490,8 +505,6 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
490
505
n_rows = int (np .ceil (n_features / float (n_cols )))
491
506
492
507
self .axes_ = np .empty ((n_rows , n_cols ), dtype = np .object )
493
- self .lines_ = np .empty ((n_rows , n_cols ), dtype = np .object )
494
- self .contours_ = np .empty ((n_rows , n_cols ), dtype = np .object )
495
508
496
509
axes_ravel = self .axes_ .ravel ()
497
510
@@ -514,14 +527,20 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
514
527
self .bounding_ax_ = None
515
528
self .figure_ = ax .ravel ()[0 ].figure
516
529
self .axes_ = ax
517
- self .lines_ = np .empty_like (ax , dtype = np .object )
518
- self .contours_ = np .empty_like (ax , dtype = np .object )
519
530
520
531
# create contour levels for two-way plots
521
532
if 2 in self .pdp_lim :
522
533
Z_level = np .linspace (* self .pdp_lim [2 ], num = 8 )
534
+
535
+ self .lines_ = np .empty_like (self .axes_ , dtype = np .object )
536
+ self .contours_ = np .empty_like (self .axes_ , dtype = np .object )
537
+ self .deciles_vlines_ = np .empty_like (self .axes_ , dtype = np .object )
538
+ self .deciles_hlines_ = np .empty_like (self .axes_ , dtype = np .object )
539
+ # Create 1d views of these 2d arrays for easy indexing
523
540
lines_ravel = self .lines_ .ravel (order = 'C' )
524
541
contours_ravel = self .contours_ .ravel (order = 'C' )
542
+ vlines_ravel = self .deciles_vlines_ .ravel (order = 'C' )
543
+ hlines_ravel = self .deciles_hlines_ .ravel (order = 'C' )
525
544
526
545
for i , axi , fx , (avg_preds , values ) in zip (count (),
527
546
self .axes_ .ravel (),
@@ -547,8 +566,8 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
547
566
trans = transforms .blended_transform_factory (axi .transData ,
548
567
axi .transAxes )
549
568
ylim = axi .get_ylim ()
550
- axi .vlines (self .deciles [fx [0 ]], 0 , 0.05 , transform = trans ,
551
- color = 'k' )
569
+ vlines_ravel [ i ] = axi .vlines (self .deciles [fx [0 ]], 0 , 0.05 ,
570
+ transform = trans , color = 'k' )
552
571
axi .set_ylim (ylim )
553
572
554
573
# Set xlabel if it is not already set
@@ -566,8 +585,8 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
566
585
trans = transforms .blended_transform_factory (axi .transAxes ,
567
586
axi .transData )
568
587
xlim = axi .get_xlim ()
569
- axi .hlines (self .deciles [fx [1 ]], 0 , 0.05 , transform = trans ,
570
- color = 'k' )
588
+ hlines_ravel [ i ] = axi .hlines (self .deciles [fx [1 ]], 0 , 0.05 ,
589
+ transform = trans , color = 'k' )
571
590
# hline erases xlim
572
591
axi .set_ylabel (self .feature_names [fx [1 ]])
573
592
axi .set_xlim (xlim )
0 commit comments