8000 Enforce that Line data modifications are sequences · matplotlib/matplotlib@cd08d18 · GitHub
[go: up one dir, main page]

Skip to content

Commit cd08d18

Browse files
committed
Enforce that Line data modifications are sequences
When creating a Line2D, x/y data is required to be a sequence. This is not enforced when modifying the data, or with Line3D.
1 parent 80073f6 commit cd08d18

File tree

6 files changed

+45
-8
lines changed

6 files changed

+45
-8
lines changed

lib/matplotlib/lines.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,6 +1228,8 @@ def set_xdata(self, x):
12281228
----------
12291229
x : 1D array
12301230
"""
1231+
if not np.iterable(x):
1232+
raise RuntimeError('x must be a sequence')
12311233
self._xorig = copy.copy(x)
12321234
self._invalidx = True
12331235
self.stale = True
@@ -1240,6 +1242,8 @@ def set_ydata(self, y):
12401242
----------
12411243
y : 1D array
12421244
"""
1245+
if not np.iterable(y):
1246+
raise RuntimeError('y must be a sequence')
12431247
self._yorig = copy.copy(y)
12441248
self._invalidy = True
12451249
self.stale = True

lib/matplotlib/tests/test_lines.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,19 @@ def test_set_line_coll_dash():
8282
ax.contour(np.random.randn(20, 30), linestyles=[(0, (3, 3))])
8383

8484

85+
def test_invalid_line_data():
86+
with pytest.raises(RuntimeError, match='xdata must be'):
87+
mlines.Line2D(0, [])
88+
with pytest.raises(RuntimeError, match='ydata must be'):
89+
mlines.Line2D([], 1)
90+
91+
line = mlines.Line2D([], [])
92+
with pytest.raises(RuntimeError, match='x must be'):
93+
line.set_xdata(0)
94+
with pytest.raises(RuntimeError, match='y must be'):
95+
line.set_ydata(0)
96+
97+
8598
@image_comparison(['line_dashes'], remove_text=True)
8699
def test_line_dashes():
87100
fig, ax = plt.subplots()

lib/matplotlib/tests/test_widgets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,7 @@ def mean(vmin, vmax):
888888
# Return mean of values in x between *vmin* and *vmax*
889889
indmin, indmax = np.searchsorted(x, (vmin, vmax))
890890
v = values[indmin:indmax].mean()
891-
ln2.set_data(x, v)
891+
ln2.set_data(x, np.full_like(x, v))
892892

893893
span = widgets.SpanSelector(ax, mean, direction='horizontal',
894894
onmove_callback=mean,
@@ -905,7 +905,7 @@ def mean(vmin, vmax):
905905
assert span._get_animated_artists() == (ln, ln2)
906906
assert ln.stale is False
907907
assert ln2.stale
908-
assert ln2.get_ydata() == 0.9547335049088455
908+
assert_allclose(ln2.get_ydata(), 0.9547335049088455)
909909
span.update()
910910
assert ln2.stale is False
911911

@@ -918,7 +918,7 @@ def mean(vmin, vmax):
918918
do_event(span, 'onmove', xdata=move_data[0], ydata=move_data[1], button=1)
919919
assert ln.stale is False
920920
assert ln2.stale
921-
assert ln2.get_ydata() == -0.9424150707548072
921+
assert_allclose(ln2.get_ydata(), -0.9424150707548072)
922922
do_event(span, 'release', xdata=release_data[0],
923923
ydata=release_data[1], button=1)
924924
assert ln2.stale is False

lib/matplotlib/widgets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3311,7 +3311,8 @@ def extents(self, extents):
33113311
# Update displayed handles
33123312
self._corner_handles.set_data(*self.corners)
33133313
self._edge_handles.set_data(*self.edge_centers)
3314-
self._center_handle.set_data(*self.center)
3314+
x, y = self.center
3315+
self._center_handle.set_data([x], [y])
33153316
self.set_visible(self.visible)
33163317
self.update()
33173318

Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __init__(self, xs, ys, zs, *args, **kwargs):
166166
Keyword arguments are passed onto :func:`~matplotlib.lines.Line2D`.
167167
"""
168168
super().__init__([], [], *args, **kwargs)
169-
self._verts3d = xs, ys, zs
169+
self.set_data_3d(xs, ys, zs)
170170

171171
def set_3d_properties(self, zs=0, zdir='z'):
172172
xs = self.get_xdata()
@@ -193,9 +193,11 @@ def set_data_3d(self, *args):
193193
Accepts x, y, z arguments or a single array-like (x, y, z)
194194
"""
195195
if len(args) == 1:
196-
self._verts3d = args[0]
197-
else:
198-
self._verts3d = args
196+
args = args[0]
197+
for name, xyz in zip('xyz', args):
198+
if not np.iterable(xyz):
199+
raise RuntimeError(f'{name} must be a sequence')
200+
self._verts3d = args
199201
self.stale = True
200202

201203
def get_data_3d(self):

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,23 @@ def test_plot_scalar(fig_test, fig_ref):
228228
ax2.plot(1, 1, "o")
229229

230230

231+
def test_invalid_line_data():
232+
with pytest.raises(RuntimeError, match='x must be'):
233+
art3d.Line3D(0, [], [])
234+
with pytest.raises(RuntimeError, match='y must be'):
235+
art3d.Line3D([], 0, [])
236+
with pytest.raises(RuntimeError, match='z must be'):
237+
art3d.Line3D([], [], 0)
238+
239+
line = art3d.Line3D([], [], [])
240+
with pytest.raises(RuntimeError, match='x must be'):
241+
line.set_data_3d(0, [], [])
242+
with pytest.raises(RuntimeError, match='y must be'):
243+
line.set_data_3d([], 0, [])
244+
with pytest.raises(RuntimeError, match='z must be'):
245+
line.set_data_3d([], [], 0)
246+
247+
231248
@mpl3d_image_comparison(['mixedsubplot.png'])
232249
def test_mixedsubplots():
233250
def f(t):

0 commit comments

Comments
 (0)
0