8000 FIX: allow colorbar mappable norm to change and do right thing · matplotlib/matplotlib@48b3510 · GitHub
[go: up one dir, main page]

Skip to content

Commit 48b3510

Browse files
committed
FIX: allow colorbar mappable norm to change and do right thing
1 parent 4432fda commit 48b3510

File tree

3 files changed

+113
-34
lines changed

3 files changed

+113
-34
lines changed

lib/matplotlib/cm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,13 @@ def set_norm(self, norm):
349349
Parameters
350350
----------
351351
norm : `.Normalize`
352+
353+
Notes
354+
-----
355+
If there are any colorbars using the mappable for this norm, setting
356+
the norm of the mappable will reset the norm, locator, and formatters
357+
on the colorbar to default.
358+
352359
"""
353360
if norm is None:
354361
norm = colors.Normalize()

lib/matplotlib/colorbar.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -387,30 +387,26 @@ def __init__(self, ax, cmap=None,
387387
self.outline = None
388388
self.patch = None
389389
self.dividers = None
390+
self.locator = None
391+
self.formatter = None
390392
self._manual_tick_data_values = None
391393

392394
if ticklocation == 'auto':
393395
ticklocation = 'bottom' if orientation == 'horizontal' else 'right'
394396
self.ticklocation = ticklocation
395397

396398
self.set_label(label)
399+
self._reset_locator_formatter_scale()
400+
397401
if np.iterable(ticks):
398402
self.locator = ticker.FixedLocator(ticks, nbins=len(ticks))
399403
else:
400404
self.locator = ticks # Handle default in _ticker()
401-
if format is None:
402-
if isinstance(self.norm, colors.LogNorm):
403-
self.formatter = ticker.LogFormatterSciNotation()
404-
elif isinstance(self.norm, colors.SymLogNorm):
405-
self.formatter = ticker.LogFormatterSciNotation(
406-
linthresh=self.norm.linthresh)
407-
else:
408-
self.formatter = ticker.ScalarFormatter()
409-
elif isinstance(format, str):
405+
406+
if isinstance(format, str):
410407
self.formatter = ticker.FormatStrFormatter(format)
411408
else:
412-
self.formatter = format # Assume it is a Formatter
413-
# The rest is in a method so we can recalculate when clim changes.
409+
self.formatter = format # Assume it is a Formatter or None
414410
self.draw_all()
415411

416412
def _extend_lower(self):
@@ -432,7 +428,6 @@ def draw_all(self):
432428
Calculate any free parameters based on the current cmap and norm,
433429
and do all the drawing.
434430
'''
435-
436431
# sets self._boundaries and self._values in real data units.
437432
# takes into account extend values:
438433
self._process_values()
@@ -449,14 +444,14 @@ def draw_all(self):
449444
if self.filled:
450445
self._add_solids(X, Y, C)
451446

447+
def set_norm(self, norm):
448+
"""
449+
set the norm of the mappable associated with this colorbar.
450+
"""
451+
self.mappable.set_norm(norm)
452+
452453
def config_axis(self):
453454
ax = self.ax
454-
if (isinstance(self.norm, colors.LogNorm)
455-
and self._use_auto_colorbar_locator()):
456-
# *both* axes are made log so that determining the
457-
# mid point is easier.
458-
ax.set_xscale('log')
459-
ax.set_yscale('log')
460455

461456
if self.orientation == 'vertical':
462457
long_axis, short_axis = ax.yaxis, ax.xaxis
@@ -504,6 +499,20 @@ def _get_ticker_locator_formatter(self):
504499
else:
505500
b = self._boundaries[self._inside]
506501
locator = ticker.FixedLocator(b, nbins=10)
502+
503+
if formatter is None:
504+
if isinstance(self.norm, colors.LogNorm):
505+
formatter = ticker.LogFormatterSciNotation()
506+
elif isinstance(self.norm, colors.SymLogNorm):
507+
formatter = ticker.LogFormatterSciNotation(
508+
linthresh=self.norm.linthresh)
509+
else:
510+
formatter = ticker.ScalarFormatter()
511+
else:
512+
formatter = self.formatter
513+
514+
self.locator = locator
515+
self.formatter = formatter
507516
_log.debug('locator: %r', locator)
508517
return locator, formatter
509518

@@ -517,6 +526,24 @@ def _use_auto_colorbar_locator(self):
517526
and ((type(self.norm) == colors.Normalize)
518527
or (type(self.norm) == colors.LogNorm)))
519528

529+
def _reset_locator_formatter_scale(self):
530+
"""
531+
Reset the locator et al to defaults. Any user-hardcoded changes
532+
need to be re-entered if this gets called (either at init, or when
533+
the mappable normal gets changed: Colorbar.update_normal)
534+
"""
535+
self.locator = None
536+
self.formatter = None
537+
if (isinstance(self.norm, colors.LogNorm)
538+
and self._use_auto_colorbar_locator()):
539+
# *both* axes are made log so that determining the
540+
# mid point is easier.
541+
self.ax.set_xscale('log')
542+
self.ax.set_yscale('log')
543+
else:
544+
self.ax.set_xscale('linear')
545+
self.ax.set_yscale('linear')
546+
520547
def update_ticks(self):
521548
"""
522549
Force the update of the ticks and ticklabels. This must be
@@ -526,7 +553,6 @@ def update_ticks(self):
526553
# get the locator and formatter. Defaults to
527554
# self.locator if not None..
528555
locator, formatter = self._get_ticker_locator_formatter()
529-
530556
if self.orientation == 'vertical':
531557
long_axis, short_axis = ax.yaxis, ax.xaxis
532558
else:
@@ -1082,7 +1108,6 @@ def __init__(self, ax, mappable, **kw):
10821108
kw['boundaries'] = CS._levels
10831109
kw['values'] = CS.cvalues
10841110
kw['extend'] = CS.extend
1085-
#kw['ticks'] = CS._levels
10861111
kw.setdefault('ticks', ticker.FixedLocator(CS.levels, nbins=10))
10871112
kw['filled'] = CS.filled
10881113
ColorbarBase.__init__(self, ax, **kw)
@@ -1105,6 +1130,7 @@ def on_mappable_changed(self, mappable):
11051130
by :func:`colorbar_factory` and should not be called manually.
11061131
11071132
"""
1133+
_log.debug('colorbar mappable changed')
11081134
self.set_cmap(mappable.get_cmap())
11091135
self.set_clim(mappable.get_clim())
11101136
self.update_normal(mappable)
@@ -1136,9 +1162,20 @@ def update_normal(self, mappable):
11361162
Update solid patches, lines, etc.
11371163
11381164
Unlike `.update_bruteforce`, this does not clear the axes. This is
1139-
meant to be called when the image or contour plot to which this
1140-
colorbar belongs changes.
1165+
meant to be called when the norm of the image or contour plot to which
1166+
this colorbar belongs changes.
1167+
1168+
This resets the locator and formatter for the axis, so if these
1169+
have been customized, they will need to be customized again.
11411170
"""
1171+
1172+
_log.debug('colorbar update normal')
1173+
self.mappable = mappable
1174+
self.set_alpha(mappable.get_alpha())
1175+
self.cmap = mappable.cmap
1176+
self.norm = mappable.norm
1177+
self._reset_locator_formatter_scale()
1178+
11421179
self.draw_all()
11431180
if isinstance(self.mappable, contour.ContourSet):
11441181
CS = self.mappable
@@ -1160,15 +1197,16 @@ def update_bruteforce(self, mappable):
11601197
# properties have been changed by methods other than the
11611198
# colorbar methods, those changes will be lost.
11621199
self.ax.cla()
1200+
self.locator = None
1201+
self.formatter = None
1202+
11631203
# clearing the axes will delete outline, patch, solids, and lines:
11641204
self.outline = None
11651205
self.patch = None
11661206
self.solids = None
11671207
self.lines = list()
11681208
self.dividers = None
1169-
self.set_alpha(mappable.get_alpha())
1170-
self.cmap = mappable.cmap
1171-
self.norm = mappable.norm
1209+
self.update_normal(mappable)
11721210
self.draw_all()
11731211
if isinstance(self.mappable, contour.ContourSet):
11741212
CS = self.mappable

lib/matplotlib/tests/test_colorbar.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from matplotlib import rc_context
55
from matplotlib.testing.decorators import image_comparison
66
import matplotlib.pyplot as plt
7-
from matplotlib.colors import BoundaryNorm, LogNorm, PowerNorm
7+
from matplotlib.colors import BoundaryNorm, LogNorm, PowerNorm, Normalize
88
from matplotlib.cm import get_cmap
9-
from matplotlib.colorbar import ColorbarBase
10-
from matplotlib.ticker import LogLocator, LogFormatter
9+
from matplotlib.colorbar import ColorbarBase, _ColorbarLogLocator
10+
from matplotlib.ticker import LogLocator, LogFormatter, FixedLocator
1111

1212

1313
def _get_cmap_norms():
@@ -424,23 +424,57 @@ def test_colorbar_renorm():
424424
fig, ax = plt.subplots()
425425
im = ax.imshow(z)
426426
cbar = fig.colorbar(im)
427+
assert np.allclose(cbar.ax.yaxis.get_majorticklocs(),
428+
np.arange(0, 120000.1, 15000))
429+
430+
cbar.set_ticks([1, 2, 3])
431+
assert isinstance(cbar.locator, FixedLocator)
427432

428433
norm = LogNorm(z.min(), z.max())
429434
im.set_norm(norm)
430-
cbar.set_norm(norm)
431-
cbar.locator = LogLocator()
432-
cbar.formatter = LogFormatter()
433-
cbar.update_normal(im)
435+
assert isinstance(cbar.locator, _ColorbarLogLocator)
436+
assert np.allclose(cbar.ax.yaxis.get_majorticklocs(),
437+
np.logspace(-8, 5, 14))
438+
# note that set_norm removes the FixedLocator...
434439
assert np.isclose(cbar.vmin, z.min())
440+
cbar.set_ticks([1, 2, 3])
441+
assert isinstance(cbar.locator, FixedLocator)
442+
assert np.allclose(cbar.ax.yaxis.get_majorticklocs(),
443+
[1.0, 2.0, 3.0])
435444

436445
norm = LogNorm(z.min() * 1000, z.max() * 1000)
437446
im.set_norm(norm)
438-
cbar.set_norm(norm)
439-
cbar.update_normal(im)
440447
assert np.isclose(cbar.vmin, z.min() * 1000)
441448
assert np.isclose(cbar.vmax, z.max() * 1000)
442449

443450

451+
def test_colorbar_format():
452+
# make sure that format is passed properly
453+
x, y = np.ogrid[-4:4:31j, -4:4:31j]
454+
z = 120000*np.exp(-x**2 - y**2)
455+
456+
fig, ax = plt.subplots()
457+
im = ax.imshow(z)
458+
cbar = fig.colorbar(im, format='%4.2e')
459+
fig.canvas.draw()
460+
assert cbar.ax.yaxis.get_ticklabels()[4].get_text() == '6.00e+04'
461+
462+
463+
def test_colorbar_scale_reset():
464+
x, y = np.ogrid[-4:4:31j, -4:4:31j]
465+
z = 120000*np.exp(-x**2 - y**2)
466+
467+
fig, ax = plt.subplots()
468+
pcm = ax.pcolormesh(z, cmap='RdBu_r', rasterized=True)
469+
cbar = fig.colorbar(pcm, ax=ax)
470+
assert cbar.ax.yaxis.get_scale() == 'linear'
471+
472+
pcm.set_norm(LogNorm(vmin=1, vmax=100))
473+
assert cbar.ax.yaxis.get_scale() == 'log'
474+
pcm.set_norm(Normalize(vmin=-20, vmax=20))
475+
assert cbar.ax.yaxis.get_scale() == 'linear'
476+
477+
444478
def test_colorbar_get_ticks():
445479
with rc_context({'_internal.classic_mode': False}):
446480

0 commit comments

Comments
 (0)
0