8000 Use axes=self as an extra args in parasite_axes · matplotlib/matplotlib@088b39a · GitHub
[go: up one dir, main page]

Skip to content

Commit 088b39a

Browse files
weiji-lianntzer
authored andcommitted
Use axes=self as an extra args in parasite_axes
Right now parasite_axes just use self._parent_axes._get_lines as self._get_lines, but it can't update the axes unit when there are twin axes. Therefore, we need to provide axes=self as an extra args to handle this. We also need to change the callees to use axes in kwargs when provided. The test creates a plot with twin axes where both axes have units. It then checks whether units are appended correctly on the respective axes. The code base without the modification fails the unit test whereas the modification makes it pass the unit test.
1 parent 449caf7 commit 088b39a

File tree

4 files changed

+62
-24
lines changed

4 files changed

+62
-24
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,7 +1720,7 @@ def plot(self, *args, scalex=True, scaley=True, data=None, **kwargs):
17201720
(``'green'``) or hex strings (``'#008000'``).
17211721
"""
17221722
kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D)
1723-
lines = [*self._get_lines(*args, data=data, **kwargs)]
1723+
lines = [*self._get_lines(self, *args, data=data, **kwargs)]
17241724
for line in lines:
17251725
self.add_line(line)
17261726
if scalex:
@@ -3578,7 +3578,7 @@ def _upcast_err(err):
35783578
# that would call self._process_unit_info again, and do other indirect
35793579
# data processing.
35803580
(data_line, base_style), = self._get_lines._plot_args(
3581-
(x, y) if fmt == '' else (x, y, fmt), kwargs, return_kwargs=True)
3581+
self, (x, y) if fmt == '' else (x, y, fmt), kwargs, return_kwargs=True)
35823582

35833583
# Do this after creating `data_line` to avoid modifying `base_style`.
35843584
if barsabove:
@@ -5286,7 +5286,7 @@ def fill(self, *args, data=None, **kwargs):
52865286
# For compatibility(!), get aliases from Line2D rather than Patch.
52875287
kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D)
52885288
# _get_patches_for_fill returns a generator, convert it to a list.
5289-
patches = [*self._get_patches_for_fill(*args, data=data, **kwargs)]
5289+
patches = [*self._get_patches_for_fill(self, *args, data=data, **kwargs)]
52905290
for poly in patches:
52915291
self.add_patch(poly)
52925292
self._request_autoscale_view()

lib/matplotlib/axes/_base.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,14 @@ class _process_plot_var_args:
219219
220220
an arbitrary number of *x*, *y*, *fmt* are allowed
221221
"""
222-
def __init__(self, axes, command='plot'):
223-
self.axes = axes
222+
223+
def __init__(self, command='plot'):
224224
self.command = command
225225
self.set_prop_cycle(None)
226226

227227
def __getstate__(self):
228228
# note: it is not possible to pickle a generator (and thus a cycler).
229-
return {'axes': self.axes, 'command': self.command}
229+
return {'command': self.command}
230230

231231
def __setstate__(self, state):
232232
self.__dict__ = state.copy()
@@ -238,8 +238,8 @@ def set_prop_cycle(self, cycler):
238238
self.prop_cycler = itertools.cycle(cycler)
239239
self._prop_keys = cycler.keys # This should make a copy
240240

241-
def __call__(self, *args, data=None, **kwargs):
242-
self.axes._process_unit_info(kwargs=kwargs)
241+
def __call__(self, axes, *args, data=None, **kwargs):
242+
axes._process_unit_info(kwargs=kwargs)
243243

244244
for pos_only in "xy":
245245
if pos_only in kwargs:
@@ -309,7 +309,7 @@ def __call__(self, *args, data=None, **kwargs):
309309
this += args[0],
310310
args = args[1:]
311311
yield from self. 67ED _plot_args(
312-
this, kwargs, ambiguous_fmt_datakey=ambiguous_fmt_datakey)
312+
axes, this, kwargs, ambiguous_fmt_datakey=ambiguous_fmt_datakey)
313313

314314
def get_next_color(self):
315315
"""Return the next color in the cycle."""
@@ -344,17 +344,17 @@ def _setdefaults(self, defaults, kw):
344344
if kw.get(k, None) is None:
345345
kw[k] = defaults[k]
346346

347-
def _makeline(self, x, y, kw, kwargs):
347+
def _makeline(self, axes, x, y, kw, kwargs):
348348
kw = {**kw, **kwargs} # Don't modify the original kw.
349349
default_dict = self._getdefaults(set(), kw)
350350
self._setdefaults(default_dict, kw)
351351
seg = mlines.Line2D(x, y, **kw)
352352
return seg, kw
353353

354-
def _makefill(self, x, y, kw, kwargs):
354+
def _makefill(self, axes, x, y, kw, kwargs):
355355
# Polygon doesn't directly support unitized inputs.
356-
x = self.axes.convert_xunits(x)
357-
y = self.axes.convert_yunits(y)
356+
x = axes.convert_xunits(x)
357+
y = axes.convert_yunits(y)
358358

359359
kw = kw.copy() # Don't modify the original kw.
360360
kwargs = kwargs.copy()
@@ -403,7 +403,7 @@ def _makefill(self, x, y, kw, kwargs):
403403
seg.set(**kwargs)
404404
return seg, kwargs
405405

406-
def _plot_args(self, tup, kwargs, *,
406+
def _plot_args(self, axes, tup, kwargs, *,
407407
return_kwargs=False, ambiguous_fmt_datakey=False):
408408
"""
409409
Process the arguments of ``plot([x], y, [fmt], **kwargs)`` calls.
@@ -495,10 +495,10 @@ def _plot_args(self, tup, kwargs, *,
495495
else:
496496
x, y = index_of(xy[-1])
497497

498-
if self.axes.xaxis is not None:
499-
self.axes.xaxis.update_units(x)
500-
if self.axes.yaxis is not None:
501-
self.axes.yaxis.update_units(y)
498+
if axes.xaxis is not None:
499+
axes.xaxis.update_units(x)
500+
if axes.yaxis is not None:
501+
axes.yaxis.update_units(y)
502502

503503
if x.shape[0] != y.shape[0]:
504504
raise ValueError(f"x and y must have same first dimension, but "
@@ -534,7 +534,7 @@ def _plot_args(self, tup, kwargs, *,
534534
else:
535535
labels = [label] * n_datasets
536536

537-
result = (make_artist(x[:, j % ncx], y[:, j % ncy], kw,
537+
result = (make_artist(axes, x[:, j % ncx], y[:, j % ncy], kw,
538538
{**kwargs, 'label': label})
539539
for j, label in enumerate(labels))
540540

@@ -1292,8 +1292,8 @@ def __clear(self):
12921292
self._tight = None
12931293
self._use_sticky_edges = True
12941294

1295-
self._get_lines = _process_plot_var_args(self)
1296-
self._get_patches_for_fill = _process_plot_var_args(self, 'fill')
1295+
self._get_lines = _process_plot_var_args()
1296+
self._get_patches_for_fill = _process_plot_var_args('fill')
12971297

12981298
self._gridOn = mpl.rcParams['axes.grid']
12991299
old_children, self._children = self._children, []

lib/mpl_toolkits/axes_grid1/tests/test_axes_grid1.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import matplotlib as mpl
55
import matplotlib.pyplot as plt
66
import matplotlib.ticker as mticker
7-
from matplotlib import cbook
7+
from matplotlib import cbook, units
88
from matplotlib.backend_bases import MouseEvent
99
from matplotlib.colors import LogNorm
1010
from matplotlib.patches import Circle, Ellipse
@@ -27,7 +27,6 @@
2727
zoomed_inset_axes, mark_inset, inset_axes, BboxConnectorPatch,
2828
InsetPosition)
2929
import mpl_toolkits.axes_grid1.mpl_axes
30-
3130
import pytest
3231

3332
import numpy as np
@@ -91,6 +90,45 @@ def test_twin_axes_empty_and_removed():
9190
plt.subplots_adjust(wspace=0.5, hspace=1)
9291

9392

93+
def test_twin_axes_both_with_units():
94+
class TestUnit:
95+
def __init__(self, val):
96+
self._val = val
97+
98+
class UnitA(TestUnit):
99+
fmt = "%0.1f Unit A"
100+
class UnitB(TestUnit):
101+
fmt = "%0.1f Unit B"
102+
103+
class UnitConverter(units.ConversionInterface):
104+
@staticmethod
105+
def convert(value, unit, axis):
106+
return [x._val for x in value]
107+
108+
@staticmethod
109+
def axisinfo(unit, axis):
110+
return units.AxisInfo(majfmt=mticker.FormatStrFormatter(unit.fmt))
111+
112+
@staticmethod
113+
def default_units(x, axis):
114+
return x[0].__class__
115+
116+
units.registry[UnitA] = UnitConverter()
117+
units.registry[UnitB] = UnitConverter()
118+
119+
host = host_subplot(111)
120+
host.plot([0, 1, 2], [UnitA(x) for x in (0, 1, 2)])
121+
122+
twin = host.twinx()
123+
twin.axis["right"].major_ticklabels.set_visible(True)
124+
twin.plot([0, 1, 2], [UnitB(x) for x in (0, 2, 2)])
125+
126+
host_labels = [l.get_text() for l in host.get_yticklabels()]
127+
twin_labels = [l.get_text() for l in twin.get_yticklabels()]
128+
assert host_labels == [f"{y} Unit A" for y in [0.0, 0.5, 1.0, 1.5, 2.0]]
129+
assert twin_labels == [f"{y} Unit B" for y in [0.0, 0.5, 1.0, 1.5, 2.0]]
130+
131+
94132
def test_axesgrid_colorbar_log_smoketest():
95133
fig = plt.figure()
96134
grid = AxesGrid(fig, 111, # modified to be only subplot

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3005,7 +3005,7 @@ def errorbar(self, x, y, z, zerr=None, yerr=None, xerr=None, fmt='',
30053005
# that would call self._process_unit_info again, and do other indirect
30063006
# data processing.
30073007
(data_line, base_style), = self._get_lines._plot_args(
3008-
(x, y) if fmt == '' else (x, y, fmt), kwargs, return_kwargs=True)
3008+
self, (x, y) if fmt == '' else (x, y, fmt), kwargs, return_kwargs=True)
30093009
art3d.line_2d_to_3d(data_line, zs=z)
30103010

30113011
# Do this after creating `data_line` to avoid modifying `base_style`.

0 commit comments

Comments
 (0)
0