diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 30b2b6a4dfb4..8cdeaa69bbc3 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -708,10 +708,8 @@ def axhline(self, y=0, xmin=0, xmax=1, **kwargs): "argument; axhline generates its own transform.") ymin, ymax = self.get_ybound() - # We need to strip away the units for comparison with - # non-unitized bounds - self._process_unit_info(ydata=y, kwargs=kwargs) - yy = self.convert_yunits(y) + # Strip away the units for comparison with non-unitized bounds. + yy, = self._process_unit_info([("y", y)], kwargs) scaley = (yy < ymin) or (yy > ymax) trans = self.get_yaxis_transform(which='grid') @@ -777,10 +775,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs): "argument; axvline generates its own transform.") xmin, xmax = self.get_xbound() - # We need to strip away the units for comparison with - # non-unitized bounds - self._process_unit_info(xdata=x, kwargs=kwargs) - xx = self.convert_xunits(x) + # Strip away the units for comparison with non-unitized bounds. + xx, = self._process_unit_info([("x", x)], kwargs) scalex = (xx < xmin) or (xx > xmax) trans = self.get_xaxis_transform(which='grid') @@ -917,19 +913,13 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs): -------- axvspan : Add a vertical span across the axes. """ + # Strip units away. self._check_no_units([xmin, xmax], ['xmin', 'xmax']) - trans = self.get_yaxis_transform(which='grid') - - # process the unit information - self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs) - - # first we need to strip away the units - xmin, xmax = self.convert_xunits([xmin, xmax]) - ymin, ymax = self.convert_yunits([ymin, ymax]) + (ymin, ymax), = self._process_unit_info([("y", [ymin, ymax])], kwargs) verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin) p = mpatches.Polygon(verts, **kwargs) - p.set_transform(trans) + p.set_transform(self.get_yaxis_transform(which="grid")) self.add_patch(p) self._request_autoscale_view(scalex=False) return p @@ -978,19 +968,13 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs): >>> axvspan(1.25, 1.55, facecolor='g', alpha=0.5) """ + # Strip units away. self._check_no_units([ymin, ymax], ['ymin', 'ymax']) - trans = self.get_xaxis_transform(which='grid') - - # process the unit information - self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs) - - # first we need to strip away the units - xmin, xmax = self.convert_xunits([xmin, xmax]) - ymin, ymax = self.convert_yunits([ymin, ymax]) + (xmin, xmax), = self._process_unit_info([("x", [xmin, xmax])], kwargs) verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)] p = mpatches.Polygon(verts, **kwargs) - p.set_transform(trans) + p.set_transform(self.get_xaxis_transform(which="grid")) self.add_patch(p) self._request_autoscale_view(scaley=False) return p @@ -1032,11 +1016,8 @@ def hlines(self, y, xmin, xmax, colors=None, linestyles='solid', """ # We do the conversion first since not all unitized data is uniform - # process the unit information - self._process_unit_info([xmin, xmax], y, kwargs=kwargs) - y = self.convert_yunits(y) - xmin = self.convert_xunits(xmin) - xmax = self.convert_xunits(xmax) + xmin, xmax, y = self._process_unit_info( + [("x", xmin), ("x", xmax), ("y", y)], kwargs) if not np.iterable(y): y = [y] @@ -1111,12 +1092,9 @@ def vlines(self, x, ymin, ymax, colors=None, linestyles='solid', axvline: vertical line across the axes """ - self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs) - # We do the conversion first since not all unitized data is uniform - x = self.convert_xunits(x) - ymin = self.convert_yunits(ymin) - ymax = self.convert_yunits(ymax) + x, ymin, ymax = self._process_unit_info( + [("x", x), ("y", ymin), ("y", ymax)], kwargs) if not np.iterable(x): x = [x] @@ -1254,14 +1232,9 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1, -------- .. plot:: gallery/lines_bars_and_markers/eventplot_demo.py """ - self._process_unit_info(xdata=positions, - ydata=[lineoffsets, linelengths], - kwargs=kwargs) - # We do the conversion first since not all unitized data is uniform - positions = self.convert_xunits(positions) - lineoffsets = self.convert_yunits(lineoffsets) - linelengths = self.convert_yunits(linelengths) + positions, lineoffsets, linelengths = self._process_unit_info( + [("x", positions), ("y", lineoffsets), ("y", linelengths)], kwargs) if not np.iterable(positions): positions = [positions] @@ -2283,11 +2256,13 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center", x = 0 if orientation == 'vertical': - self._process_unit_info(xdata=x, ydata=height, kwargs=kwargs) + self._process_unit_info( + [("x", x), ("y", height)], kwargs, convert=False) if log: self.set_yscale('log', nonpositive='clip') elif orientation == 'horizontal': - self._process_unit_info(xdata=width, ydata=y, kwargs=kwargs) + self._process_unit_info( + [("x", width), ("y", y)], kwargs, convert=False) if log: self.set_xscale('log', nonpositive='clip') @@ -2567,9 +2542,8 @@ def broken_barh(self, xranges, yrange, **kwargs): ydata = cbook.safe_first_element(yrange) else: ydata = None - self._process_unit_info(xdata=xdata, - ydata=ydata, - kwargs=kwargs) + self._process_unit_info( + [("x", xdata), ("y", ydata)], kwargs, convert=False) xranges_conv = [] for xr in xranges: if len(xr) != 2: @@ -2689,13 +2663,9 @@ def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, locs, heads, *args = args if orientation == 'vertical': - self._process_unit_info(xdata=locs, ydata=heads) - locs = self.convert_xunits(locs) - heads = self.convert_yunits(heads) + locs, heads = self._process_unit_info([("x", locs), ("y", heads)]) else: - self._process_unit_info(xdata=heads, ydata=locs) - heads = self.convert_xunits(heads) - locs = self.convert_yunits(locs) + heads, locs = self._process_unit_info([("x", heads), ("y", locs)]) # defaults for formats if linefmt is None: @@ -3179,7 +3149,7 @@ def errorbar(self, x, y, yerr=None, xerr=None, if int(offset) != offset: raise ValueError("errorevery's starting index must be an integer") - self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs) + self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False) # Make sure all the args are iterable; use lists not arrays to preserve # units. @@ -4346,9 +4316,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, """ # Process **kwargs to handle aliases, conflicts with explicit kwargs: - self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs) - x = self.convert_xunits(x) - y = self.convert_yunits(y) + x, y = self._process_unit_info([("x", x), ("y", y)], kwargs) # np.ma.ravel yields an ndarray, not a masked array, # unless its argument is a masked array. @@ -4577,7 +4545,7 @@ def reduce_C_function(C: array) -> float %(PolyCollection)s """ - self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs) + self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False) x, y, C = cbook.delete_masked_points(x, y, C) @@ -4926,9 +4894,7 @@ def quiverkey(self, Q, X, Y, U, label, **kw): def _quiver_units(self, args, kw): if len(args) > 3: x, y = args[0:2] - self._process_unit_info(xdata=x, ydata=y, kwargs=kw) - x = self.convert_xunits(x) - y = self.convert_yunits(y) + x, y = self._process_unit_info([("x", x), ("y", y)], kw) return (x, y) + args[2:] return args @@ -5114,17 +5080,9 @@ def _fill_between_x_or_y( self._get_patches_for_fill.get_next_color() # Handle united data, such as dates - self._process_unit_info( - **{f"{ind_dir}data": ind, f"{dep_dir}data": dep1}, kwargs=kwargs) - self._process_unit_info( - **{f"{dep_dir}data": dep2}) - - # Convert the arrays so we can work with them - ind = ma.masked_invalid(getattr(self, f"convert_{ind_dir}units")(ind)) - dep1 = ma.masked_invalid( - getattr(self, f"convert_{dep_dir}units")(dep1)) - dep2 = ma.masked_invalid( - getattr(self, f"convert_{dep_dir}units")(dep2)) + ind, dep1, dep2 = map( + ma.masked_invalid, self._process_unit_info( + [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs)) for name, array in [ (ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]: @@ -5739,9 +5697,7 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None, Ny, Nx = X.shape # unit conversion allows e.g. datetime objects as axis values - self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs) - X = self.convert_xunits(X) - Y = self.convert_yunits(Y) + X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs) # convert to MA, if necessary. C = ma.asarray(C) @@ -6016,9 +5972,7 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, X = X.ravel() Y = Y.ravel() # unit conversion allows e.g. datetime objects as axis values - self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs) - X = self.convert_xunits(X) - Y = self.convert_yunits(Y) + X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs) # convert to one dimensional arrays C = C.ravel() @@ -6497,16 +6451,23 @@ def hist(self, x, bins=None, range=None, density=False, weights=None, x = cbook._reshape_2D(x, 'x') nx = len(x) # number of datasets - # Process unit information - # Unit conversion is done individually on each dataset - self._process_unit_info(xdata=x[0], kwargs=kwargs) - x = [self.convert_xunits(xi) for xi in x] + # Process unit information. _process_unit_info sets the unit and + # converts the first dataset; then we convert each following dataset + # one at a time. + if orientation == "vertical": + convert_units = self.convert_xunits + x = [*self._process_unit_info([("x", x[0])], kwargs), + *map(convert_units, x[1:])] + else: # horizontal + convert_units = self.convert_yunits + x = [*self._process_unit_info([("y", x[0])], kwargs), + *map(convert_units, x[1:])] if bin_range is not None: - bin_range = self.convert_xunits(bin_range) + bin_range = convert_units(bin_range) if not cbook.is_scalar_or_string(bins): - bins = self.convert_xunits(bins) + bins = convert_units(bins) # We need to do to 'weights' what was done to 'x' if weights is not None: @@ -6787,9 +6748,8 @@ def stairs(self, values, edges=None, *, if edges is None: edges = np.arange(len(values) + 1) - self._process_unit_info(xdata=edges, ydata=values, kwargs=kwargs) - edges = self.convert_xunits(edges) - values = self.convert_yunits(values) + edges, values = self._process_unit_info( + [("x", edges), ("y", values)], kwargs) patch = mpatches.StepPatch(values, edges, diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 40a4af4630b0..02f3475bbfd1 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -2208,47 +2208,64 @@ def update_datalim_bounds(self, bounds): """ self.dataLim.set(mtransforms.Bbox.union([self.dataLim, bounds])) - def _process_unit_info(self, xdata=None, ydata=None, kwargs=None): + def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True): """ - Look for unit *kwargs* and update the axis instances as necessary + Set axis units based on *datasets* and *kwargs*, and optionally apply + unit conversions to *datasets*. + Parameters + ---------- + datasets : list + List of (axis_name, dataset) pairs (where the axis name is defined + as in `._get_axis_map`. + kwargs : dict + Other parameters from which unit info (i.e., the *xunits*, + *yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for + polar axes) entries) is popped, if present. Note that this dict is + mutated in-place! + convert : bool, default: True + Whether to return the original datasets or the converted ones. - .. warning :: - - This method may mutate the dictionary passed in an kwargs and - the Axis instances attached to this Axes. - """ - - def _process_single_axis(data, axis, unit_name, kwargs): - # Return if there's no axis set + Returns + ------- + list + Either the original datasets if *convert* is False, or the + converted ones if *convert* is True (the default). + """ + # The API makes datasets a list of pairs rather than an axis_name to + # dataset mapping because it is sometimes necessary to process multiple + # datasets for a single axis, and concatenating them may be tricky + # (e.g. if some are scalars, etc.). + datasets = datasets or [] + kwargs = kwargs or {} + axis_map = self._get_axis_map() + for axis_name, data in datasets: + try: + axis = axis_map[axis_name] + except KeyError: + raise ValueError(f"Invalid axis name: {axis_name!r}") from None + # Update from data if axis is already set but no unit is set yet. + if axis is not None and data is not None and not axis.have_units(): + axis.update_units(data) + for axis_name, axis in axis_map.items(): + # Return if no axis is set. if axis is None: - return kwargs - - if data is not None: - # We only need to update if there is nothing set yet. - if not axis.have_units(): - axis.update_units(data) - - # Check for units in the kwargs, and if present update axis - if kwargs is not None: - units = kwargs.pop(unit_name, axis.units) - if self.name == 'polar': - # handle special casing to allow the kwargs - # thetaunits and runits to be used with polar - polar_units = {'xunits': 'thetaunits', 'yunits': 'runits'} - units = kwargs.pop(polar_units[unit_name], units) - - if units != axis.units and units is not None: - axis.set_units(units) - # If the units being set imply a different converter, - # we need to update. - if data is not None: + continue + # Check for units in the kwargs, and if present update axis. + units = kwargs.pop(f"{axis_name}units", axis.units) + if self.name == "polar": + # Special case: polar supports "thetaunits"/"runits". + polar_units = {"x": "thetaunits", "y": "runits"} + units = kwargs.pop(polar_units[axis_name], units) + if units != axis.units and units is not None: + axis.set_units(units) + # If the units being set imply a different converter, + # we need to update again. + for dataset_axis_name, data in datasets: + if dataset_axis_name == axis_name and data is not None: axis.update_units(data) - return kwargs - - kwargs = _process_single_axis(xdata, self.xaxis, 'xunits', kwargs) - kwargs = _process_single_axis(ydata, self.yaxis, 'yunits', kwargs) - return kwargs + return [axis_map[axis_name].convert_units(data) if convert else data + for axis_name, data in datasets] def in_axes(self, mouseevent): """ @@ -3400,7 +3417,7 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False, raise TypeError('Cannot pass both `xmax` and `right`') right = xmax - self._process_unit_info(xdata=(left, right)) + self._process_unit_info([("x", (left, right))], convert=False) left = self._validate_converted_limits(left, self.convert_xunits) right = self._validate_converted_limits(right, self.convert_xunits) @@ -3725,7 +3742,7 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False, raise TypeError('Cannot pass both `ymax` and `top`') top = ymax - self._process_unit_info(ydata=(bottom, top)) + self._process_unit_info([("y", (bottom, top))], convert=False) bottom = self._validate_converted_limits(bottom, self.convert_yunits) top = self._validate_converted_limits(top, self.convert_yunits) diff --git a/lib/matplotlib/contour.py b/lib/matplotlib/contour.py index 015ded0ebdc7..fb8eba585d65 100644 --- a/lib/matplotlib/contour.py +++ b/lib/matplotlib/contour.py @@ -1475,9 +1475,7 @@ def _check_xyz(self, args, kwargs): convert them to 2D using meshgrid. """ x, y = args[:2] - kwargs = self.axes._process_unit_info(xdata=x, ydata=y, kwargs=kwargs) - x = self.axes.convert_xunits(x) - y = self.axes.convert_yunits(y) + x, y = self.axes._process_unit_info([("x", x), ("y", y)], kwargs) x = np.asarray(x, dtype=np.float64) y = np.asarray(y, dtype=np.float64) diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 5ebec4ba61b8..3864af2bab2c 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -162,29 +162,6 @@ def convert_zunits(self, z): """ return self.zaxis.convert_units(z) - def _process_unit_info(self, xdata=None, ydata=None, zdata=None, - kwargs=None): - """Update the axis instances based on unit *kwargs* if given.""" - super()._process_unit_info(xdata=xdata, ydata=ydata, kwargs=kwargs) - - if self.xaxis is None or self.yaxis is None or self.zaxis is None: - return - - if zdata is not None: - # we only need to update if there is nothing set yet. - if not self.zaxis.have_units(): - self.zaxis.update_units(xdata) - - # process kwargs 2nd since these will override default units - if kwargs is not None: - zunits = kwargs.pop('zunits', self.zaxis.units) - if zunits != self.zaxis.units: - self.zaxis.set_units(zunits) - # If the units being set imply a different converter, - # we need to update. - if zdata is not None: - self.zaxis.update_units(zdata) - def set_top_view(self): # this happens to be the right view for the viewing coordinates # moved up and to the left slightly to fit labels and axes @@ -747,7 +724,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, raise TypeError('Cannot pass both `xmax` and `right`') right = xmax - self._process_unit_info(xdata=(left, right)) + self._process_unit_info([("x", (left, right))], convert=False) left = self._validate_converted_limits(left, self.convert_xunits) right = self._validate_converted_limits(right, self.convert_xunits) @@ -801,7 +778,7 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, raise TypeError('Cannot pass both `ymax` and `top`') top = ymax - self._process_unit_info(ydata=(bottom, top)) + self._process_unit_info([("y", (bottom, top))], convert=False) bottom = self._validate_converted_limits(bottom, self.convert_yunits) top = self._validate_converted_limits(top, self.convert_yunits) @@ -856,7 +833,7 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False, raise TypeError('Cannot pass both `zmax` and `top`') top = zmax - self._process_unit_info(zdata=(bottom, top)) + self._process_unit_info([("z", (bottom, top))], convert=False) bottom = self._validate_converted_limits(bottom, self.convert_zunits) top = self._validate_converted_limits(top, self.convert_zunits)