8000 Merge pull request #21558 from anntzer/sp · matplotlib/matplotlib@f0632c0 · GitHub
[go: up one dir, main page]

Skip to content

Commit f0632c0

Browse files
authored
Merge pull request #21558 from anntzer/sp
Various small fixes for streamplot().
2 parents b525983 + cf204e5 commit f0632c0

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

lib/matplotlib/streamplot.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
116116
if use_multicolor_lines:
117117
if color.shape != grid.shape:
118118
raise ValueError("If 'color' is given, it must match the shape of "
119-
"'Grid(x, y)'")
120-
line_colors = []
119+
"the (x, y) grid")
120+
line_colors = [[]] # Empty entry allows concatenation of zero arrays.
121121
color = np.ma.masked_invalid(color)
122122
else:
123123
line_kw['color'] = color
@@ -126,7 +126,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
126126
if isinstance(linewidth, np.ndarray):
127127
if linewidth.shape != grid.shape:
128128
raise ValueError("If 'linewidth' is given, it must match the "
129-
"shape of 'Grid(x, y)'")
129+
"shape of the (x, y) grid")
130130
line_kw['linewidth'] = []
131131
else:
132132
line_kw['linewidth'] = linewidth
@@ -137,7 +137,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
137137

138138
# Sanity checks.
139139
if u.shape != grid.shape or v.shape != grid.shape:
140-
raise ValueError("'u' and 'v' must match the shape of 'Grid(x, y)'")
140+
raise ValueError("'u' and 'v' must match the shape of the (x, y) grid")
141141

142142
u = np.ma.masked_invalid(u)
143143
v = np.ma.masked_invalid(v)
@@ -310,21 +310,22 @@ class Grid:
310310
"""Grid of data."""
311311
def __init__(self, x, y):
312312

313-
if x.ndim == 1:
313+
if np.ndim(x) == 1:
314314
pass
315-
elif x.ndim == 2:
316-
x_row = x[0, :]
315+
elif np.ndim(x) == 2:
316+
x_row = x[0]
317317
if not np.allclose(x_row, x):
318318
raise ValueError("The rows of 'x' must be equal")
319319
x = x_row
320320
else:
321321
raise ValueError("'x' can have at maximum 2 dimensions")
322322

323-
if y.ndim == 1:
323+
if np.ndim(y) == 1:
324324
pass
325-
elif y.ndim == 2:
326-
y_col = y[:, 0]
327-
if not np.allclose(y_col, y.T):
325+
elif np.ndim(y) == 2:
326+
yt = np.transpose(y) # Also works for nested lists.
327+
y_col = yt[0]
328+
if not np.allclose(y_col, yt):
328329
raise ValueError("The columns of 'y' must be equal")
329330
y = y_col
330331
else:

lib/matplotlib/tests/test_streamplot.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,13 @@ def test_streamplot_grid():
144144

145145
with pytest.raises(ValueError, match="'y' must be strictly increasing"):
146146
plt.streamplot(x, y, u, v)
147+
148+
149+
def test_streamplot_inputs(): # test no exception occurs.
150+
# fully-masked
151+
plt.streamplot(np.arange(3), np.arange(3),
152+
np.full((3, 3), np.nan), np.full((3, 3), np.nan),
153+
color=np.random.rand(3, 3))
154+
# array-likes
155+
plt.streamplot(range(3), range(3),
156+
np.random.rand(3, 3), np.random.rand(3, 3))

0 commit comments

Comments
 (0)
0