8000 Set norms using scale names. · matplotlib/matplotlib@3c66118 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3c66118

Browse files
committed
Set norms using scale names.
1 parent 28e5798 commit 3c66118

File tree

4 files changed

+110
-29
lines changed

4 files changed

+110
-29
lines changed

lib/matplotlib/cm.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
"""
1717

1818
from collections.abc import Mapping, MutableMapping
19+
import functools
1920

2021
import numpy as np
2122
from numpy import ma
2223

2324
import matplotlib as mpl
24-
from matplotlib import _api, colors, cbook
25+
from matplotlib import _api, colors, cbook, scale
2526
from matplotlib._cm import datad
2627
from matplotlib._cm_listed import cmaps as cmaps_listed
2728

@@ -331,6 +332,34 @@ def unregister_cmap(name):
331332
return _cmap_registry.pop(name)
332333

333334

335+
def _auto_norm_from_scale(scale_cls):
336+
"""
337+
Automatically generate a norm class from *scale_cls*.
338+
339+
This differs from `.colors.make_norm_from_scale` in the following points:
340+
341+
- This function is not a class decorator, but directly returns a norm class
342+
(as if decorating `.Normalize`).
343+
- The scale is automatically constructed with ``nonpositive="mask"``, if it
344+
supports such a parameter, to work around the difference in defaults
345+
between standard scales (which use "clip") and norms (which use "mask").
346+
347+
Note that ``make_norm_from_scale`` caches the generated norm classes and
348+
reuses them for later calls. For example, ``_auto_norm_from_scale("log")
349+
== LogNorm``.
350+
"""
351+
# Actually try to construct an instance, to verify whether
352+
# ``nonpositive="mask"`` is supported.
353+
try:
354+
norm = colors.make_norm_from_scale(
355+
functools.partial(scale_cls, nonpositive="mask"))(
356+
colors.Normalize)()
357+
except TypeError:
358+
norm = colors.make_norm_from_scale(scale_cls)(
359+
colors.Normalize)()
360+
return type(norm)
361+
362+
334363
class ScalarMappable:
335364
"""
336365
A mixin class to map scalar data to RGBA.
@@ -341,12 +370,13 @@ class ScalarMappable:
341370

342371
def __init__(self, norm=None, cmap=None):
343372
"""
344-
345373
Parameters
346374
----------
347-
norm : `matplotlib.colors.Normalize` (or subclass thereof)
375+
norm : `.Normalize` (or subclass thereof) or str or None
348376
The normalizing object which scales data, typically into the
349377
interval ``[0, 1]``.
378+
If a `str`, a `.Normalize` subclass is dynamically generated based
379+
on the scale with the corresponding name.
350380
If *None*, *norm* defaults to a *colors.Normalize* object which
351381
initializes its scaling based on the first data processed.
352382
cmap : str or `~matplotlib.colors.Colormap`
@@ -376,11 +406,11 @@ def _scale_norm(self, norm, vmin, vmax):
376406
"""
377407
if vmin is not None or vmax is not None:
378408
self.set_clim(vmin, vmax)
379-
if norm is not None:
409+
if isinstance(norm, colors.Normalize):
380410
raise ValueError(
381-
"Passing parameters norm and vmin/vmax simultaneously is "
382-
"not supported. Please pass vmin/vmax directly to the "
383-
"norm when creating it.")
411+
"Passing a Normalize instance simultaneously with "
412+
"vmin/vmax is not supported. Please pass vmin/vmax "
413+
"directly to the norm when creating it.")
384414

385415
# always resolve the autoscaling so we have concrete limits
386416
# rather than deferring to draw time.
@@ -554,9 +584,18 @@ def norm(self):
554584

555585
@norm.setter
556586
def norm(self, norm):
557-
_api.check_isinstance((colors.Normalize, None), norm=norm)
587+
_api.check_isinstance((colors.Normalize, str, None), norm=norm)
558588
if norm is None:
559589
norm = colors.Normalize()
590+
elif isinstance(norm, str):
591+
# case-insensitive, consistently with scale_factory.
592+
try:
593+
scale_cls = scale._scale_mapping[norm.lower()]
594+
except KeyError:
595+
raise ValueError(
596+
"Invalid norm str name; the following values are "
597+
"supported: {}".format(", ".join(scale._scale_mapping)))
598+
norm = _auto_norm_from_scale(scale_cls)()
560599

561600
if norm is self.norm:
562601
# We aren't updating anything
@@ -578,7 +617,7 @@ def set_norm(self, norm):
578617
579618
Parameters
580619
----------
581-
norm : `.Normalize` or None
620+
norm : `.Normalize` or str or None
582621
583622
Notes
584623
-----

lib/matplotlib/colors.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,35 +1505,54 @@ class norm_cls(Normalize):
15051505
if base_norm_cls is None:
15061506
return functools.partial(make_norm_from_scale, scale_cls, init=init)
15071507

1508+
if isinstance(scale_cls, functools.partial):
1509+
scale_args = scale_cls.args
1510+
scale_kwargs_items = tuple(scale_cls.keywords.items())
1511+
scale_cls = scale_cls.func
1512+
else:
1513+
scale_args = scale_kwargs_items = ()
1514+
15081515
if init is None:
15091516
def init(vmin=None, vmax=None, clip=False): pass
15101517

15111518
return _make_norm_from_scale(
1512-
scale_cls, base_norm_cls, inspect.signature(init))
1519+
scale_cls, scale_args, scale_kwargs_items,
1520+
base_norm_cls, inspect.signature(init))
15131521

15141522

15151523
@functools.lru_cache(None)
1516-
def _make_norm_from_scale(scale_cls, base_norm_cls, bound_init_signature):
1524+
def _make_norm_from_scale(
1525+
scale_cls, scale_args, scale_kwargs_items,
1526+
base_norm_cls, bound_init_signature,
1527+
):
15171528
"""
15181529
Helper for `make_norm_from_scale`.
15191530
1520-
This function is split out so that it takes a signature object as third
1521-
argument (as signatures are picklable, contrary to arbitrary lambdas);
1522-
caching is also used so that different unpickles reuse the same class.
1531+
This function is split out to enable caching (in particular so that
1532+
different unpickles reuse the same class). In order to do so,
1533+
1534+
- ``functools.partial`` *scale_cls* is expanded into ``func, args, kwargs``
1535+
to allow memoizing returned norms (partial instances always compare
1536+
unequal, but we can check identity based on ``func, args, kwargs``;
1537+
- *init* is replaced by *init_signature*, as signatures are picklable,
1538+
unlike to arbitrary lambdas.
15231539
"""
15241540

15251541
class Norm(base_norm_cls):
15261542
def __reduce__(self):
15271543
return (_picklable_norm_constructor,
1528-
(scale_cls, base_norm_cls, bound_init_signature),
1544+
(scale_cls, scale_args, scale_kwargs_items,
1545+
base_norm_cls, bound_init_signature),
15291546
self.__dict__)
15301547

15311548
def __init__(self, *args, **kwargs):
15321549
ba = bound_init_signature.bind(*args, **kwargs)
15331550
ba.apply_defaults()
15341551
super().__init__(
15351552
**{k: ba.arguments.pop(k) for k in ["vmin", "vmax", "clip"]})
1536-
self._scale = scale_cls(axis=None, **ba.arguments)
1553+
self._scale = functools.partial(
1554+
scale_cls, *scale_args, **dict(scale_kwargs_items))(
1555+
axis=None, **ba.arguments)
15371556
self._trf = self._scale.get_transform()
15381557

15391558
__init__.__signature__ = bound_init_signature.replace(parameters=[
@@ -1587,12 +1606,12 @@ def autoscale_None(self, A):
15871606
in_trf_domain = np.extract(np.isfinite(self._trf.transform(A)), A)
15881607
return super().autoscale_None(in_trf_domain)
15891608

1590-
Norm.__name__ = (
1591-
f"{scale_cls.__name__}Norm" if base_norm_cls is Normalize
1592-
else base_norm_cls.__name__)
1593-
Norm.__qualname__ = (
1594-
f"{scale_cls.__qualname__}Norm" if base_norm_cls is Normalize
1595-
else base_norm_cls.__qualname__)
1609+
if base_norm_cls is Normalize:
1610+
Norm.__name__ = f"{scale_cls.__name__}Norm"
1611+
Norm.__qualname__ = f"{scale_cls.__qualname__}Norm"
1612+
else:
1613+
Norm.__name__ = base_norm_cls.__name__
1614+
Norm.__qualname__ = base_norm_cls.__qualname__
15961615
Norm.__module__ = base_norm_cls.__module__
15971616
Norm.__doc__ = base_norm_cls.__doc__
15981617

@@ -1637,9 +1656,10 @@ def forward(values: array-like) -> array-like
16371656
"""
16381657

16391658

1640-
@make_norm_from_scale(functools.partial(scale.LogScale, nonpositive="mask"))
1641-
class LogNorm(Normalize):
1642-
"""Normalize a given value to the 0-1 range on a log scale."""
1659+
LogNorm = make_norm_from_scale(
1660+
functools.partial(scale.LogScale, nonpositive="mask"))(Normalize)
1661+
LogNorm.__name__ = LogNorm.__qualname__ = "LogNorm"
1662+
LogNorm.__doc__ = "Normalize a given value to the 0-1 range on a log scale."
16431663

16441664

16451665
@make_norm_from_scale(

lib/matplotlib/tests/test_axes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -923,8 +923,8 @@ def test_imshow_norm_vminvmax():
923923
a = [[1, 2], [3, 4]]
924924
ax = plt.axes()
925925
with pytest.raises(ValueError,
926-
match="Passing parameters norm and vmin/vmax "
927-
"simultaneously is not supported."):
926+
match="Passing a Normalize instance simultaneously "
927+
"with vmin/vmax is not supported."):
928928
ax.imshow(a, norm=mcolors.Normalize(-10, 10), vmin=0, vmax=5)
929929

930930

@@ -2279,8 +2279,8 @@ def test_scatter_norm_vminvmax(self):
22792279
x = [1, 2, 3]
22802280
ax = plt.axes()
22812281
with pytest.raises(ValueError,
2282-
match="Passing parameters norm and vmin/vmax "
2283-
"simultaneously is not supported."):
2282+
match="Passing a Normalize instance simultaneously "
2283+
"with vmin/vmax is not supported."):
22842284
ax.scatter(x, x, c=x, norm=mcolors.Normalize(-10, 10),
22852285
vmin=0, vmax=5)
22862286

lib/matplotlib/tests/test_image.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,3 +1376,25 @@ def test_rgba_antialias():
13761376
# alternating red and blue stripes become purple
13771377
axs[3].imshow(aa, interpolation='antialiased', interpolation_stage='rgba',
13781378
cmap=cmap, vmin=-1.2, vmax=1.2)
1379+
1380+
1381+
@check_figures_equal(extensions=["png"])
1382+
def test_str_norms(fig_test, fig_ref):
1383+
t = np.random.rand(10, 10) * .8 + .1 # between 0 and 1
1384+
axts = fig_test.subplots(1, 5)
1385+
axts[0].imshow(t, norm="log")
1386+
axts[1].imshow(t, norm="log", vmin=.2)
1387+
axts[2].imshow(t, norm="symlog")
1388+
axts[3].imshow(t, norm="symlog", vmin=.3, vmax=.7)
1389+
axts[4].imshow(t, norm="logit", vmin=.3, vmax=.7)
1390+
axrs = fig_ref.subplots(1, 5)
1391+
axrs[0].imshow(t, norm=colors.LogNorm())
1392+
axrs[1].imshow(t, norm=colors.LogNorm(vmin=.2))
1393+
# same linthresh as SymmetricalLogScale's default.
1394+
axrs[2].imshow(t, norm=colors.SymLogNorm(linthresh=2))
1395+
axrs[3].imshow(t, norm=colors.SymLogNorm(linthresh=2, vmin=.3, vmax=.7))
1396+
axrs[4].imshow(t, norm="logit", clim=(.3, .7))
1397+
1398+
assert type(axts[0].images[0].norm) == colors.LogNorm # Exactly that class
1399+
with pytest.raises(ValueError):
1400+
axts[0].imshow(t, norm="foobar")

0 commit comments

Comments
 (0)
0