diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py
index 30b2b6a4dfb4..8cdeaa69bbc3 100644
--- a/lib/matplotlib/axes/_axes.py
+++ b/lib/matplotlib/axes/_axes.py
@@ -708,10 +708,8 @@ def axhline(self, y=0, xmin=0, xmax=1, **kwargs):
                              "argument; axhline generates its own transform.")
         ymin, ymax = self.get_ybound()
 
-        # We need to strip away the units for comparison with
-        # non-unitized bounds
-        self._process_unit_info(ydata=y, kwargs=kwargs)
-        yy = self.convert_yunits(y)
+        # Strip away the units for comparison with non-unitized bounds.
+        yy, = self._process_unit_info([("y", y)], kwargs)
         scaley = (yy < ymin) or (yy > ymax)
 
         trans = self.get_yaxis_transform(which='grid')
@@ -777,10 +775,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
                              "argument; axvline generates its own transform.")
         xmin, xmax = self.get_xbound()
 
-        # We need to strip away the units for comparison with
-        # non-unitized bounds
-        self._process_unit_info(xdata=x, kwargs=kwargs)
-        xx = self.convert_xunits(x)
+        # Strip away the units for comparison with non-unitized bounds.
+        xx, = self._process_unit_info([("x", x)], kwargs)
         scalex = (xx < xmin) or (xx > xmax)
 
         trans = self.get_xaxis_transform(which='grid')
@@ -917,19 +913,13 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
         --------
         axvspan : Add a vertical span across the axes.
         """
+        # Strip units away.
         self._check_no_units([xmin, xmax], ['xmin', 'xmax'])
-        trans = self.get_yaxis_transform(which='grid')
-
-        # process the unit information
-        self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
-
-        # first we need to strip away the units
-        xmin, xmax = self.convert_xunits([xmin, xmax])
-        ymin, ymax = self.convert_yunits([ymin, ymax])
+        (ymin, ymax), = self._process_unit_info([("y", [ymin, ymax])], kwargs)
 
         verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
         p = mpatches.Polygon(verts, **kwargs)
-        p.set_transform(trans)
+        p.set_transform(self.get_yaxis_transform(which="grid"))
         self.add_patch(p)
         self._request_autoscale_view(scalex=False)
         return p
@@ -978,19 +968,13 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
         >>> axvspan(1.25, 1.55, facecolor='g', alpha=0.5)
 
         """
+        # Strip units away.
         self._check_no_units([ymin, ymax], ['ymin', 'ymax'])
-        trans = self.get_xaxis_transform(which='grid')
-
-        # process the unit information
-        self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
-
-        # first we need to strip away the units
-        xmin, xmax = self.convert_xunits([xmin, xmax])
-        ymin, ymax = self.convert_yunits([ymin, ymax])
+        (xmin, xmax), = self._process_unit_info([("x", [xmin, xmax])], kwargs)
 
         verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
         p = mpatches.Polygon(verts, **kwargs)
-        p.set_transform(trans)
+        p.set_transform(self.get_xaxis_transform(which="grid"))
         self.add_patch(p)
         self._request_autoscale_view(scaley=False)
         return p
@@ -1032,11 +1016,8 @@ def hlines(self, y, xmin, xmax, colors=None, linestyles='solid',
         """
 
         # We do the conversion first since not all unitized data is uniform
-        # process the unit information
-        self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
-        y = self.convert_yunits(y)
-        xmin = self.convert_xunits(xmin)
-        xmax = self.convert_xunits(xmax)
+        xmin, xmax, y = self._process_unit_info(
+            [("x", xmin), ("x", xmax), ("y", y)], kwargs)
 
         if not np.iterable(y):
             y = [y]
@@ -1111,12 +1092,9 @@ def vlines(self, x, ymin, ymax, colors=None, linestyles='solid',
         axvline: vertical line across the axes
         """
 
-        self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
-
         # We do the conversion first since not all unitized data is uniform
-        x = self.convert_xunits(x)
-        ymin = self.convert_yunits(ymin)
-        ymax = self.convert_yunits(ymax)
+        x, ymin, ymax = self._process_unit_info(
+            [("x", x), ("y", ymin), ("y", ymax)], kwargs)
 
         if not np.iterable(x):
             x = [x]
@@ -1254,14 +1232,9 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
         --------
         .. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
         """
-        self._process_unit_info(xdata=positions,
-                                ydata=[lineoffsets, linelengths],
-                                kwargs=kwargs)
-
         # We do the conversion first since not all unitized data is uniform
-        positions = self.convert_xunits(positions)
-        lineoffsets = self.convert_yunits(lineoffsets)
-        linelengths = self.convert_yunits(linelengths)
+        positions, lineoffsets, linelengths = self._process_unit_info(
+            [("x", positions), ("y", lineoffsets), ("y", linelengths)], kwargs)
 
         if not np.iterable(positions):
             positions = [positions]
@@ -2283,11 +2256,13 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center",
                 x = 0
 
         if orientation == 'vertical':
-            self._process_unit_info(xdata=x, ydata=height, kwargs=kwargs)
+            self._process_unit_info(
+                [("x", x), ("y", height)], kwargs, convert=False)
             if log:
                 self.set_yscale('log', nonpositive='clip')
         elif orientation == 'horizontal':
-            self._process_unit_info(xdata=width, ydata=y, kwargs=kwargs)
+            self._process_unit_info(
+                [("x", width), ("y", y)], kwargs, convert=False)
             if log:
                 self.set_xscale('log', nonpositive='clip')
 
@@ -2567,9 +2542,8 @@ def broken_barh(self, xranges, yrange, **kwargs):
             ydata = cbook.safe_first_element(yrange)
         else:
             ydata = None
-        self._process_unit_info(xdata=xdata,
-                                ydata=ydata,
-                                kwargs=kwargs)
+        self._process_unit_info(
+            [("x", xdata), ("y", ydata)], kwargs, convert=False)
         xranges_conv = []
         for xr in xranges:
             if len(xr) != 2:
@@ -2689,13 +2663,9 @@ def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0,
             locs, heads, *args = args
 
         if orientation == 'vertical':
-            self._process_unit_info(xdata=locs, ydata=heads)
-            locs = self.convert_xunits(locs)
-            heads = self.convert_yunits(heads)
+            locs, heads = self._process_unit_info([("x", locs), ("y", heads)])
         else:
-            self._process_unit_info(xdata=heads, ydata=locs)
-            heads = self.convert_xunits(heads)
-            locs = self.convert_yunits(locs)
+            heads, locs = self._process_unit_info([("x", heads), ("y", locs)])
 
         # defaults for formats
         if linefmt is None:
@@ -3179,7 +3149,7 @@ def errorbar(self, x, y, yerr=None, xerr=None,
         if int(offset) != offset:
             raise ValueError("errorevery's starting index must be an integer")
 
-        self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
+        self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
 
         # Make sure all the args are iterable; use lists not arrays to preserve
         # units.
@@ -4346,9 +4316,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
         """
         # Process **kwargs to handle aliases, conflicts with explicit kwargs:
 
-        self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
-        x = self.convert_xunits(x)
-        y = self.convert_yunits(y)
+        x, y = self._process_unit_info([("x", x), ("y", y)], kwargs)
 
         # np.ma.ravel yields an ndarray, not a masked array,
         # unless its argument is a masked array.
@@ -4577,7 +4545,7 @@ def reduce_C_function(C: array) -> float
             %(PolyCollection)s
 
         """
-        self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
+        self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
 
         x, y, C = cbook.delete_masked_points(x, y, C)
 
@@ -4926,9 +4894,7 @@ def quiverkey(self, Q, X, Y, U, label, **kw):
     def _quiver_units(self, args, kw):
         if len(args) > 3:
             x, y = args[0:2]
-            self._process_unit_info(xdata=x, ydata=y, kwargs=kw)
-            x = self.convert_xunits(x)
-            y = self.convert_yunits(y)
+            x, y = self._process_unit_info([("x", x), ("y", y)], kw)
             return (x, y) + args[2:]
         return args
 
@@ -5114,17 +5080,9 @@ def _fill_between_x_or_y(
                     self._get_patches_for_fill.get_next_color()
 
         # Handle united data, such as dates
-        self._process_unit_info(
-            **{f"{ind_dir}data": ind, f"{dep_dir}data": dep1}, kwargs=kwargs)
-        self._process_unit_info(
-            **{f"{dep_dir}data": dep2})
-
-        # Convert the arrays so we can work with them
-        ind = ma.masked_invalid(getattr(self, f"convert_{ind_dir}units")(ind))
-        dep1 = ma.masked_invalid(
-            getattr(self, f"convert_{dep_dir}units")(dep1))
-        dep2 = ma.masked_invalid(
-            getattr(self, f"convert_{dep_dir}units")(dep2))
+        ind, dep1, dep2 = map(
+            ma.masked_invalid, self._process_unit_info(
+                [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs))
 
         for name, array in [
                 (ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]:
@@ -5739,9 +5697,7 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None,
         Ny, Nx = X.shape
 
         # unit conversion allows e.g. datetime objects as axis values
-        self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs)
-        X = self.convert_xunits(X)
-        Y = self.convert_yunits(Y)
+        X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs)
 
         # convert to MA, if necessary.
         C = ma.asarray(C)
@@ -6016,9 +5972,7 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None,
         X = X.ravel()
         Y = Y.ravel()
         # unit conversion allows e.g. datetime objects as axis values
-        self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs)
-        X = self.convert_xunits(X)
-        Y = self.convert_yunits(Y)
+        X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs)
 
         # convert to one dimensional arrays
         C = C.ravel()
@@ -6497,16 +6451,23 @@ def hist(self, x, bins=None, range=None, density=False, weights=None,
         x = cbook._reshape_2D(x, 'x')
         nx = len(x)  # number of datasets
 
-        # Process unit information
-        # Unit conversion is done individually on each dataset
-        self._process_unit_info(xdata=x[0], kwargs=kwargs)
-        x = [self.convert_xunits(xi) for xi in x]
+        # Process unit information.  _process_unit_info sets the unit and
+        # converts the first dataset; then we convert each following dataset
+        # one at a time.
+        if orientation == "vertical":
+            convert_units = self.convert_xunits
+            x = [*self._process_unit_info([("x", x[0])], kwargs),
+                 *map(convert_units, x[1:])]
+        else:  # horizontal
+            convert_units = self.convert_yunits
+            x = [*self._process_unit_info([("y", x[0])], kwargs),
+                 *map(convert_units, x[1:])]
 
         if bin_range is not None:
-            bin_range = self.convert_xunits(bin_range)
+            bin_range = convert_units(bin_range)
 
         if not cbook.is_scalar_or_string(bins):
-            bins = self.convert_xunits(bins)
+            bins = convert_units(bins)
 
         # We need to do to 'weights' what was done to 'x'
         if weights is not None:
@@ -6787,9 +6748,8 @@ def stairs(self, values, edges=None, *,
         if edges is None:
             edges = np.arange(len(values) + 1)
 
-        self._process_unit_info(xdata=edges, ydata=values, kwargs=kwargs)
-        edges = self.convert_xunits(edges)
-        values = self.convert_yunits(values)
+        edges, values = self._process_unit_info(
+            [("x", edges), ("y", values)], kwargs)
 
         patch = mpatches.StepPatch(values,
                                    edges,
diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py
index 40a4af4630b0..02f3475bbfd1 100644
--- a/lib/matplotlib/axes/_base.py
+++ b/lib/matplotlib/axes/_base.py
@@ -2208,47 +2208,64 @@ def update_datalim_bounds(self, bounds):
         """
         self.dataLim.set(mtransforms.Bbox.union([self.dataLim, bounds]))
 
-    def _process_unit_info(self, xdata=None, ydata=None, kwargs=None):
+    def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
         """
-        Look for unit *kwargs* and update the axis instances as necessary
+        Set axis units based on *datasets* and *kwargs*, and optionally apply
+        unit conversions to *datasets*.
 
+        Parameters
+        ----------
+        datasets : list
+            List of (axis_name, dataset) pairs (where the axis name is defined
+            as in `._get_axis_map`.
+        kwargs : dict
+            Other parameters from which unit info (i.e., the *xunits*,
+            *yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
+            polar axes) entries) is popped, if present.  Note that this dict is
+            mutated in-place!
+        convert : bool, default: True
+            Whether to return the original datasets or the converted ones.
 
-        .. warning ::
-
-           This method may mutate the dictionary passed in an kwargs and
-           the Axis instances attached to this Axes.
-        """
-
-        def _process_single_axis(data, axis, unit_name, kwargs):
-            # Return if there's no axis set
+        Returns
+        -------
+        list
+            Either the original datasets if *convert* is False, or the
+            converted ones if *convert* is True (the default).
+        """
+        # The API makes datasets a list of pairs rather than an axis_name to
+        # dataset mapping because it is sometimes necessary to process multiple
+        # datasets for a single axis, and concatenating them may be tricky
+        # (e.g. if some are scalars, etc.).
+        datasets = datasets or []
+        kwargs = kwargs or {}
+        axis_map = self._get_axis_map()
+        for axis_name, data in datasets:
+            try:
+                axis = axis_map[axis_name]
+            except KeyError:
+                raise ValueError(f"Invalid axis name: {axis_name!r}") from None
+            # Update from data if axis is already set but no unit is set yet.
+            if axis is not None and data is not None and not axis.have_units():
+                axis.update_units(data)
+        for axis_name, axis in axis_map.items():
+            # Return if no axis is set.
             if axis is None:
-                return kwargs
-
-            if data is not None:
-                # We only need to update if there is nothing set yet.
-                if not axis.have_units():
-                    axis.update_units(data)
-
-            # Check for units in the kwargs, and if present update axis
-            if kwargs is not None:
-                units = kwargs.pop(unit_name, axis.units)
-                if self.name == 'polar':
-                    # handle special casing to allow the kwargs
-                    # thetaunits and runits to be used with polar
-                    polar_units = {'xunits': 'thetaunits', 'yunits': 'runits'}
-                    units = kwargs.pop(polar_units[unit_name], units)
-
-                if units != axis.units and units is not None:
-                    axis.set_units(units)
-                    # If the units being set imply a different converter,
-                    # we need to update.
-                    if data is not None:
+                continue
+            # Check for units in the kwargs, and if present update axis.
+            units = kwargs.pop(f"{axis_name}units", axis.units)
+            if self.name == "polar":
+                # Special case: polar supports "thetaunits"/"runits".
+                polar_units = {"x": "thetaunits", "y": "runits"}
+                units = kwargs.pop(polar_units[axis_name], units)
+            if units != axis.units and units is not None:
+                axis.set_units(units)
+                # If the units being set imply a different converter,
+                # we need to update again.
+                for dataset_axis_name, data in datasets:
+                    if dataset_axis_name == axis_name and data is not None:
                         axis.update_units(data)
-            return kwargs
-
-        kwargs = _process_single_axis(xdata, self.xaxis, 'xunits', kwargs)
-        kwargs = _process_single_axis(ydata, self.yaxis, 'yunits', kwargs)
-        return kwargs
+        return [axis_map[axis_name].convert_units(data) if convert else data
+                for axis_name, data in datasets]
 
     def in_axes(self, mouseevent):
         """
@@ -3400,7 +3417,7 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,
                 raise TypeError('Cannot pass both `xmax` and `right`')
             right = xmax
 
-        self._process_unit_info(xdata=(left, right))
+        self._process_unit_info([("x", (left, right))], convert=False)
         left = self._validate_converted_limits(left, self.convert_xunits)
         right = self._validate_converted_limits(right, self.convert_xunits)
 
@@ -3725,7 +3742,7 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
                 raise TypeError('Cannot pass both `ymax` and `top`')
             top = ymax
 
-        self._process_unit_info(ydata=(bottom, top))
+        self._process_unit_info([("y", (bottom, top))], convert=False)
         bottom = self._validate_converted_limits(bottom, self.convert_yunits)
         top = self._validate_converted_limits(top, self.convert_yunits)
 
diff --git a/lib/matplotlib/contour.py b/lib/matplotlib/contour.py
index 015ded0ebdc7..fb8eba585d65 100644
--- a/lib/matplotlib/contour.py
+++ b/lib/matplotlib/contour.py
@@ -1475,9 +1475,7 @@ def _check_xyz(self, args, kwargs):
         convert them to 2D using meshgrid.
         """
         x, y = args[:2]
-        kwargs = self.axes._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
-        x = self.axes.convert_xunits(x)
-        y = self.axes.convert_yunits(y)
+        x, y = self.axes._process_unit_info([("x", x), ("y", y)], kwargs)
 
         x = np.asarray(x, dtype=np.float64)
         y = np.asarray(y, dtype=np.float64)
diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py
index 5ebec4ba61b8..3864af2bab2c 100644
--- a/lib/mpl_toolkits/mplot3d/axes3d.py
+++ b/lib/mpl_toolkits/mplot3d/axes3d.py
@@ -162,29 +162,6 @@ def convert_zunits(self, z):
         """
         return self.zaxis.convert_units(z)
 
-    def _process_unit_info(self, xdata=None, ydata=None, zdata=None,
-                           kwargs=None):
-        """Update the axis instances based on unit *kwargs* if given."""
-        super()._process_unit_info(xdata=xdata, ydata=ydata, kwargs=kwargs)
-
-        if self.xaxis is None or self.yaxis is None or self.zaxis is None:
-            return
-
-        if zdata is not None:
-            # we only need to update if there is nothing set yet.
-            if not self.zaxis.have_units():
-                self.zaxis.update_units(xdata)
-
-        # process kwargs 2nd since these will override default units
-        if kwargs is not None:
-            zunits = kwargs.pop('zunits', self.zaxis.units)
-            if zunits != self.zaxis.units:
-                self.zaxis.set_units(zunits)
-                # If the units being set imply a different converter,
-                # we need to update.
-                if zdata is not None:
-                    self.zaxis.update_units(zdata)
-
     def set_top_view(self):
         # this happens to be the right view for the viewing coordinates
         # moved up and to the left slightly to fit labels and axes
@@ -747,7 +724,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
                 raise TypeError('Cannot pass both `xmax` and `right`')
             right = xmax
 
-        self._process_unit_info(xdata=(left, right))
+        self._process_unit_info([("x", (left, right))], convert=False)
         left = self._validate_converted_limits(left, self.convert_xunits)
         right = self._validate_converted_limits(right, self.convert_xunits)
 
@@ -801,7 +778,7 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False,
                 raise TypeError('Cannot pass both `ymax` and `top`')
             top = ymax
 
-        self._process_unit_info(ydata=(bottom, top))
+        self._process_unit_info([("y", (bottom, top))], convert=False)
         bottom = self._validate_converted_limits(bottom, self.convert_yunits)
         top = self._validate_converted_limits(top, self.convert_yunits)
 
@@ -856,7 +833,7 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False,
                 raise TypeError('Cannot pass both `zmax` and `top`')
             top = zmax
 
-        self._process_unit_info(zdata=(bottom, top))
+        self._process_unit_info([("z", (bottom, top))], convert=False)
         bottom = self._validate_converted_limits(bottom, self.convert_zunits)
         top = self._validate_converted_limits(top, self.convert_zunits)