@@ -437,26 +437,6 @@ def draw(self, renderer):
437
437
self .offsetText .set_ha (align )
438
438
self .offsetText .draw (renderer )
439
439
440
- if self .axes ._draw_grid and len (ticks ):
441
- # Grid points where the planes meet
442
- xyz0 = np .tile (minmax , (len (ticks ), 1 ))
443
- xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
444
-
445
- # Grid lines go from the end of one plane through the plane
446
- # intersection (at xyz0) to the end of the other plane. The first
447
- # point (0) differs along dimension index-2 and the last (2) along
448
- # dimension index-1.
449
- lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
450
- lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
451
- lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
452
- self .gridlines .set_segments (lines )
453
- gridinfo = info ['grid' ]
454
- self .gridlines .set_color (gridinfo ['color' ])
455
- self .gridlines .set_linewidth (gridinfo ['linewidth' ])
456
- self .gridlines .set_linestyle (gridinfo ['linestyle' ])
457
- self .gridlines .do_3d_projection ()
458
- self .gridlines .draw (renderer )
459
-
460
440
# Draw ticks:
461
441
tickdir = self ._get_tickdir ()
462
442
tickdelta = deltas [tickdir ] if highs [tickdir ] else - deltas [tickdir ]
@@ -494,6 +474,46 @@ def draw(self, renderer):
494
474
renderer .close_group ('axis3d' )
495
475
self .stale = False
496
476
477
+ @artist .allow_rasterization
478
+ def draw_grid (self , renderer ):
479
+ if not self .axes ._draw_grid :
480
+ return
481
+
482
+ self .label ._transform = self .axes .transData
483
+ renderer .open_group ("grid3d" , gid = self .get_gid ())
484
+
485
+ ticks = self ._update_ticks ()
486
+ if len (ticks ):
487
+ # Get general axis information:
488
+ info = self ._axinfo
489
+ index = info ["i" ]
490
+
491
+ mins , maxs , tc , highs = self ._get_coord_info ()
492
+
493
+ minmax = np .where (highs , maxs , mins )
494
+ maxmin = np .where (~ highs , maxs , mins )
495
+
496
+ # Grid points where the planes meet
497
+ xyz0 = np .tile (minmax , (len (ticks ), 1 ))
498
+ xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
499
+
500
+ # Grid lines go from the end of one plane through the plane
501
+ # intersection (at xyz0) to the end of the other plane. The first
502
+ # point (0) differs along dimension index-2 and the last (2) along
503
+ # dimension index-1.
504
+ lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
505
+ lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
506
+ lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
507
+ self .gridlines .set_segments (lines )
508
+ gridinfo = info ['grid' ]
509
+ self .gridlines .set_color (gridinfo ['color' ])
510
+ self .gridlines .set_linewidth (gridinfo ['linewidth' ])
511
+ self .gridlines .set_linestyle (gridinfo ['linestyle' ])
512
+ self .gridlines .do_3d_projection ()
513
+ self .gridlines .draw (renderer )
514
+
515
+ renderer .close_group ('grid3d' )
516
+
497
517
# TODO: Get this to work (more) properly when mplot3d supports the
498
518
# transforms framework.
499
519
def get_tightbbox (self , renderer = None , * , for_layout_only = False ):
0 commit comments