diff --git a/doc/api/next_api_changes/behavior/27721-DS.rst b/doc/api/next_api_changes/behavior/27721-DS.rst new file mode 100644 index 000000000000..b451ea30dc58 --- /dev/null +++ b/doc/api/next_api_changes/behavior/27721-DS.rst @@ -0,0 +1,20 @@ +Unit converters can now support units in images +----------------------------------------------- + +`~.cm.ScalarMappable` can now contain data with units. This adds support for +unit-ful data to be plotted using - `~.axes.Axes.imshow`, `~.axes.Axes.pcolor`, +and `~.axes.Axes.pcolormesh` + +For this to be supported by third-party `~.units.ConversionInterface`, +the `~.units.ConversionInterface.default_units` and +`~.units.ConversionInterface.convert` methods must allow for the *axis* +argument to be ``None``, and `~.units.ConversionInterface.convert` must be able to +convert data of more than one dimension (e.g. when plotting images the data is 2D). + +If a conversion interface raises an error when given ``None`` or 2D data as described +above, this error will be re-raised when a user tries to use one of the newly supported +plotting methods with unit-ful data. + +If you have a custom conversion interface you want to forbid using with image data, the +`~.units.ConversionInterface` methods that accept a ``units`` parameter should raise +a `matplotlib.units.ConversionError` when given ``units=None``. diff --git a/doc/users/next_whats_new/image_units.rst b/doc/users/next_whats_new/image_units.rst new file mode 100644 index 000000000000..bcaad73ded4c --- /dev/null +++ b/doc/users/next_whats_new/image_units.rst @@ -0,0 +1,12 @@ +Unit support for images +----------------------- +This adds support for image data with units has been added to the following plotting +methods: + +- `~.axes.Axes.imshow` +- `~.axes.Axes.pcolor` +- `~.axes.Axes.pcolormesh` + +If the data has units, the ``vmin`` and ``vmax`` units to these methods can also have +units, and if you add a colorbar the ``levels`` argument to ``colorbar`` can also +have units. diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index cc310b7da0d7..2bfc79630983 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -8,6 +8,7 @@ from numpy import ma import matplotlib as mpl +import matplotlib.cm as cm import matplotlib.category # Register category unit converter as side effect. import matplotlib.cbook as cbook import matplotlib.collections as mcoll @@ -5838,12 +5839,28 @@ def imshow(self, X, cmap=None, norm=None, *, aspect=None, self.add_image(im) return im + @staticmethod + def _convert_C_units(C): + """ + Remove any units attached to C, and return the units and converter used to do + the conversion. + """ + sm = cm.ScalarMappable() + C = sm._strip_units(C) + converter = sm._converter + units = sm._units + + C = np.asanyarray(C) + C = cbook.safe_masked_invalid(C, copy=True) + return C, units, converter + def _pcolorargs(self, funcname, *args, shading='auto', **kwargs): # - create X and Y if not present; # - reshape X and Y as needed if they are 1-D; # - check for proper sizes based on `shading` kwarg; # - reset shading if shading='auto' to flat or nearest # depending on size; + # - if C has units, get the converter _valid_shading = ['gouraud', 'nearest', 'flat', 'auto'] try: @@ -5855,7 +5872,7 @@ def _pcolorargs(self, funcname, *args, shading='auto', **kwargs): shading = 'auto' if len(args) == 1: - C = np.asanyarray(args[0]) + C, units, converter = self._convert_C_units(args[0]) nrows, ncols = C.shape[:2] if shading in ['gouraud', 'nearest']: X, Y = np.meshgrid(np.arange(ncols), np.arange(nrows)) @@ -5863,11 +5880,11 @@ def _pcolorargs(self, funcname, *args, shading='auto', **kwargs): X, Y = np.meshgrid(np.arange(ncols + 1), np.arange(nrows + 1)) shading = 'flat' C = cbook.safe_masked_invalid(C, copy=True) - return X, Y, C, shading + return X, Y, C, shading, units, converter if len(args) == 3: # Check x and y for bad data... - C = np.asanyarray(args[2]) + C, units, converter = self._convert_C_units(args[2]) # unit conversion allows e.g. datetime objects as axis values X, Y = args[:2] X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs) @@ -5948,7 +5965,7 @@ def _interp_grid(X): shading = 'flat' C = cbook.safe_masked_invalid(C, copy=True) - return X, Y, C, shading + return X, Y, C, shading, units, converter @_preprocess_data() @_docstring.dedent_interpd @@ -6100,8 +6117,9 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None, if shading is None: shading = mpl.rcParams['pcolor.shading'] shading = shading.lower() - X, Y, C, shading = self._pcolorargs('pcolor', *args, shading=shading, - kwargs=kwargs) + X, Y, C, shading, units, converter = self._pcolorargs( + 'pcolor', *args, shading=shading, kwargs=kwargs + ) linewidths = (0.25,) if 'linewidth' in kwargs: kwargs['linewidths'] = kwargs.pop('linewidth') @@ -6137,6 +6155,8 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None, collection = mcoll.PolyQuadMesh( coords, array=C, cmap=cmap, norm=norm, alpha=alpha, **kwargs) + collection._units = units + collection._converter = converter collection._scale_norm(norm, vmin, vmax) # Transform from native to data coordinates? @@ -6356,8 +6376,9 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, shading = shading.lower() kwargs.setdefault('edgecolors', 'none') - X, Y, C, shading = self._pcolorargs('pcolormesh', *args, - shading=shading, kwargs=kwargs) + X, Y, C, shading, units, converter = self._pcolorargs( + 'pcolormesh', *args, shading=shading, kwargs=kwargs + ) coords = np.stack([X, Y], axis=-1) kwargs.setdefault('snap', mpl.rcParams['pcolormesh.snap']) @@ -6365,6 +6386,8 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, collection = mcoll.QuadMesh( coords, antialiased=antialiased, shading=shading, array=C, cmap=cmap, norm=norm, alpha=alpha, **kwargs) + collection._units = units + collection._converter = converter collection._scale_norm(norm, vmin, vmax) coords = coords.reshape(-1, 2) # flatten the grid structure; keep x, y diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index 4ac2379ea5f5..4492770a84e1 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -101,6 +101,8 @@ def default_units(data, axis): object storing string to integer mapping """ # the conversion call stack is default_units -> axis_info -> convert + if axis is None: + return UnitData(data) if axis.units is None: axis.set_units(UnitData(data)) else: @@ -208,7 +210,7 @@ def update(self, data): TypeError If elements in *data* are neither str nor bytes. """ - data = np.atleast_1d(np.array(data, dtype=object)) + data = np.atleast_1d(np.array(data, dtype=object).ravel()) # check if convertible to number: convertible = True for val in OrderedDict.fromkeys(data): diff --git a/lib/matplotlib/cm.py b/lib/matplotlib/cm.py index c14973560ac3..b9b18dc28594 100644 --- a/lib/matplotlib/cm.py +++ b/lib/matplotlib/cm.py @@ -24,6 +24,7 @@ from matplotlib import _api, colors, cbook, scale from matplotlib._cm import datad from matplotlib._cm_listed import cmaps as cmaps_listed +import matplotlib.units as munits _LUTSIZE = mpl.rcParams['image.lut'] @@ -283,6 +284,8 @@ def __init__(self, norm=None, cmap=None): The colormap used to map normalized data values to RGBA colors. """ self._A = None + self._units = None + self._converter = None self._norm = None # So that the setter knows we're initializing. self.set_norm(norm) # The Normalize instance of this ScalarMappable. self.cmap = None # So that the setter knows we're initializing. @@ -393,6 +396,41 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True): rgba = self.cmap(x, alpha=alpha, bytes=bytes) return rgba + def _strip_units(self, A): + """ + Remove units from A, and save the units and converter used to do the conversion. + """ + self._converter = munits.registry.get_converter(A) + if self._converter is None: + self._units = None + return A + + try: + self._units = self._converter.default_units(A, None) + except Exception as e: + if isinstance(e, munits.ConversionError): + raise e + + raise RuntimeError( + f'{self._converter} failed when trying to return the default units for ' + 'this image. This may be because support has not been ' + 'implemented for `axis=None` in the default_units() method.' + ) from e + + try: + A = self._converter.convert(A, self._units, None) + except Exception as e: + if isinstance(e, munits.ConversionError): + raise e + + raise munits.ConversionError( + f'{self._converter} failed when trying to convert the units for this ' + 'image. This may be because support has not been implemented ' + 'for `axis=None` in the convert() method.' + ) from e + + return A + def set_array(self, A): """ Set the value array from array-like *A*. @@ -408,7 +446,7 @@ def set_array(self, A): if A is None: self._A = None return - + A = self._strip_units(A) A = cbook.safe_masked_invalid(A, copy=True) if not np.can_cast(A.dtype, float, "same_kind"): raise TypeError(f"Image data of dtype {A.dtype} cannot be " @@ -458,10 +496,14 @@ def set_clim(self, vmin=None, vmax=None): vmin, vmax = vmin except (TypeError, ValueError): pass - if vmin is not None: - self.norm.vmin = colors._sanitize_extrema(vmin) - if vmax is not None: - self.norm.vmax = colors._sanitize_extrema(vmax) + + def _process_lim(lim): + if self._converter is not None: + lim = self._converter.convert(lim, self._units, axis=None) + return colors._sanitize_extrema(lim) + + self.norm.vmin = _process_lim(vmin) + self.norm.vmax = _process_lim(vmax) def get_alpha(self): """ diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index fb137cc503e1..0796b926f8f1 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -1751,7 +1751,7 @@ def __init__(self, widths, heights, angles, *, units='points', **kwargs): self._widths = 0.5 * np.asarray(widths).ravel() self._heights = 0.5 * np.asarray(heights).ravel() self._angles = np.deg2rad(angles).ravel() - self._units = units + self._length_units = units self.set_transform(transforms.IdentityTransform()) self._transforms = np.empty((0, 3, 3)) self._paths = [mpath.Path.unit_circle()] @@ -1762,24 +1762,24 @@ def _set_transforms(self): ax = self.axes fig = self.figure - if self._units == 'xy': + if self._length_units == 'xy': sc = 1 - elif self._units == 'x': + elif self._length_units == 'x': sc = ax.bbox.width / ax.viewLim.width - elif self._units == 'y': + elif self._length_units == 'y': sc = ax.bbox.height / ax.viewLim.height - elif self._units == 'inches': + elif self._length_units == 'inches': sc = fig.dpi - elif self._units == 'points': + elif self._length_units == 'points': sc = fig.dpi / 72.0 - elif self._units == 'width': + elif self._length_units == 'width': sc = ax.bbox.width - elif self._units == 'height': + elif self._length_units == 'height': sc = ax.bbox.height - elif self._units == 'dots': + elif self._length_units == 'dots': sc = 1.0 else: - raise ValueError(f'Unrecognized units: {self._units!r}') + raise ValueError(f'Unrecognized units: {self._length_units!r}') self._transforms = np.zeros((len(self._widths), 3, 3)) widths = self._widths * sc @@ -1793,7 +1793,7 @@ def _set_transforms(self): self._transforms[:, 2, 2] = 1.0 _affine = transforms.Affine2D - if self._units == 'xy': + if self._length_units == 'xy': m = ax.transData.get_affine().get_matrix().copy() m[:2, 2:] = 0 self.set_transform(_affine(m)) diff --git a/lib/matplotlib/colorbar.py b/lib/matplotlib/colorbar.py index af61e4671ff4..dfa046656191 100644 --- a/lib/matplotlib/colorbar.py +++ b/lib/matplotlib/colorbar.py @@ -291,7 +291,7 @@ def __init__(self, ax, mappable=None, *, cmap=None, drawedges=False, extendfrac=None, extendrect=False, - label='', + label=None, location=None, ): @@ -391,12 +391,17 @@ def __init__(self, ax, mappable=None, *, cmap=None, orientation) if location is None else location self.ticklocation = ticklocation - self.set_label(label) + if label is not None: + self.set_label(label) self._reset_locator_formatter_scale() + self._set_units_from_mappable() + + if ticks is not None and self._converter is not None: + ticks = self._converter.convert(ticks, self._units, self._long_axis()) if np.iterable(ticks): self._locator = ticker.FixedLocator(ticks, nbins=len(ticks)) - else: + elif isinstance(ticks, ticker.Locator): self._locator = ticks if isinstance(format, str): @@ -1331,6 +1336,27 @@ def drag_pan(self, button, key, x, y): elif self.orientation == 'vertical': self.norm.vmin, self.norm.vmax = points[:, 1] + def _set_units_from_mappable(self): + """ + Set the colorbar locator and formatter if the mappable has units. + """ + self._units = self.mappable._units + self._converter = self.mappable._converter + if self._converter is not None: + + axis = self._long_axis() + info = self._converter.axisinfo(self._units, axis) + + if info is not None: + if info.majloc is not None: + self.locator = info.majloc + if info.minloc is not None: + self.minorlocator = info.minloc + if info.majfmt is not None: + self.formatter = info.majfmt + if info.minfmt is not None: + self.minorformatter = info.minfmt + ColorbarBase = Colorbar # Backcompat API diff --git a/lib/matplotlib/colorbar.pyi b/lib/matplotlib/colorbar.pyi index f71c5759fc55..b065214ff445 100644 --- a/lib/matplotlib/colorbar.pyi +++ b/lib/matplotlib/colorbar.pyi @@ -59,7 +59,7 @@ class Colorbar: drawedges: bool = ..., extendfrac: Literal["auto"] | float | Sequence[float] | None = ..., extendrect: bool = ..., - label: str = ..., + label: str | None = ..., location: Literal["left", "right", "top", "bottom"] | None = ... ) -> None: ... @property diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index 73738fe3bdbe..294c38951d18 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -726,6 +726,8 @@ def set_data(self, A): """ if isinstance(A, PIL.Image.Image): A = pil_to_array(A) # Needed e.g. to apply png palette. + + A = self._strip_units(A) self._A = self._normalize_image_array(A) self._imcache = None self.stale = True @@ -1140,6 +1142,7 @@ def set_data(self, x, y, A): (M, N) `~numpy.ndarray` or masked array of values to be colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array. """ + A = self._strip_units(A) A = self._normalize_image_array(A) x = np.array(x, np.float32) y = np.array(y, np.float32) @@ -1300,6 +1303,7 @@ def set_data(self, x, y, A): - (M, N, 3): RGB array - (M, N, 4): RGBA array """ + A = self._strip_units(A) A = self._normalize_image_array(A) x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel() y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel() diff --git a/lib/matplotlib/tests/baseline_images/test_units/mappable_units.png b/lib/matplotlib/tests/baseline_images/test_units/mappable_units.png new file mode 100644 index 000000000000..4d1896e9189b Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_units/mappable_units.png differ diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 5baaeaa5d388..aa666ad790a3 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -794,7 +794,7 @@ def test_collection_set_array(): # Test set_array with wrong dtype with pytest.raises(TypeError, match="^Image data of dtype"): - c.set_array("wrong_input") + c.set_array(object()) # Test if array kwarg is copied vals[5] = 45 diff --git a/lib/matplotlib/tests/test_pickle.py b/lib/matplotlib/tests/test_pickle.py index cab412bff561..549bab79ba51 100644 --- a/lib/matplotlib/tests/test_pickle.py +++ b/lib/matplotlib/tests/test_pickle.py @@ -111,7 +111,7 @@ def test_complete(fig_test, fig_ref): loaded.canvas.draw() fig_test.set_size_inches(loaded.get_size_inches()) - fig_test.figimage(loaded.canvas.renderer.buffer_rgba()) + fig_test.figimage(np.asarray(loaded.canvas.renderer.buffer_rgba())) plt.close(loaded) @@ -151,7 +151,7 @@ def test_pickle_load_from_subprocess(fig_test, fig_ref, tmp_path): loaded_fig.canvas.draw() fig_test.set_size_inches(loaded_fig.get_size_inches()) - fig_test.figimage(loaded_fig.canvas.renderer.buffer_rgba()) + fig_test.figimage(np.asarray(loaded_fig.canvas.renderer.buffer_rgba())) plt.close(loaded_fig) diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index a5fd32dfb3e5..a9c771d3a6dd 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -41,6 +41,9 @@ def __getitem__(self, item): def __array__(self): return np.asarray(self.magnitude) + def __len__(self): + return len(self.__array__()) + @pytest.fixture def quantity_converter(): @@ -302,3 +305,42 @@ def test_plot_kernel(): # just a smoketest that fail kernel = Kernel([1, 2, 3, 4, 5]) plt.plot(kernel) + + +@image_comparison(['mappable_units.png'], style="mpl20") +def test_mappable_units(quantity_converter): + # Check that showing an image with units works + munits.registry[Quantity] = quantity_converter + x, y = np.meshgrid([0, 1], [0, 1]) + data = Quantity(np.arange(4).reshape(2, 2), 'hours') + vmin = Quantity(1, "hours") # Test a limit different from min of the data + vmax = Quantity(3 * 60, "minutes") # Test a different unit to the data + + fig, axs = plt.subplots(nrows=2, ncols=2, constrained_layout=True) + + # imshow + ax = axs[0, 0] + mappable = ax.imshow(data, origin='lower', vmin=vmin, vmax=vmax) + cbar = fig.colorbar(mappable, ax=ax, extend="min") + + # pcolor + # + # Use datetime to check that the locator/formatter is set correctly + # Also test ticks argument to colorbar + data = [[datetime(2024, 1, 1), datetime(2024, 1, 2)], + [datetime(2024, 1, 3), datetime(2024, 1, 4)]] + vmin = datetime(2024, 1, 2) + vmax = datetime(2024, 1, 5) + ax = axs[0, 1] + mappable = ax.pcolor(x, y, data, vmin=vmin, vmax=vmax) + + ticks = [datetime(2024, 1, 2), datetime(2024, 1, 3), datetime(2024, 1, 5)] + fig.colorbar(mappable, ax=ax, extend="min", ticks=ticks) + + # pcolormesh + horizontal colorbar + categorical + data = [["one", "two"], ["three", "four"]] + ax = axs[1, 0] + mappable = ax.pcolormesh(x, y, data) + fig.colorbar(mappable, ax=ax, orientation="horizontal", extend="min") + + axs[1, 1].axis("off") diff --git a/lib/matplotlib/units.py b/lib/matplotlib/units.py index e3480f228bb4..f886eaefe342 100644 --- a/lib/matplotlib/units.py +++ b/lib/matplotlib/units.py @@ -118,16 +118,22 @@ def axisinfo(unit, axis): @staticmethod def default_units(x, axis): - """Return the default unit for *x* or ``None`` for the given axis.""" + """ + Return the default unit for *x*. + + *axis* may be an `~.axis.Axis` or ``None``. + """ return None @staticmethod def convert(obj, unit, axis): """ - Convert *obj* using *unit* for the specified *axis*. + Convert *obj* using *unit*. If *obj* is a sequence, return the converted sequence. The output must be a sequence of scalars that can be used by the numpy array layer. + + *axis* may be an `~.axis.Axis` or ``None``. """ return obj @@ -186,7 +192,7 @@ def get_converter(self, x): else: # ... and avoid infinite recursion for pathological iterables for # which indexing returns instances of the same iterable class. - if type(first) is not type(x): + if isinstance(first, list) or type(first) is not type(x): return self.get_converter(first) return None