@@ -116,8 +116,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
116
116
if use_multicolor_lines :
117
117
if color .shape != grid .shape :
118
118
raise ValueError ("If 'color' is given, it must match the shape of "
119
- "'Grid (x, y)' " )
120
- line_colors = []
119
+ "the (x, y) grid " )
120
+ line_colors = [[]] # Empty entry allows concatenation of zero arrays.
121
121
color = np .ma .masked_invalid (color )
122
122
else :
123
123
line_kw ['color' ] = color
@@ -126,7 +126,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
126
126
if isinstance (linewidth , np .ndarray ):
127
127
if linewidth .shape != grid .shape :
128
128
raise ValueError ("If 'linewidth' is given, it must match the "
129
- "shape of 'Grid (x, y)' " )
129
+ "shape of the (x, y) grid " )
130
130
line_kw ['linewidth' ] = []
131
131
else :
132
132
line_kw ['linewidth' ] = linewidth
@@ -137,7 +137,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
137
137
138
138
# Sanity checks.
139
139
if u .shape != grid .shape or v .shape != grid .shape :
140
- raise ValueError ("'u' and 'v' must match the shape of 'Grid (x, y)' " )
140
+ raise ValueError ("'u' and 'v' must match the shape of the (x, y) grid " )
141
141
142
142
u = np .ma .masked_invalid (u )
143
143
v = np .ma .masked_invalid (v )
@@ -310,21 +310,22 @@ class Grid:
310
310
"""Grid of data."""
311
311
def __init__ (self , x , y ):
312
312
313
- if x .ndim == 1 :
313
+ if np .ndim ( x ) == 1 :
314
314
pass
315
- elif x .ndim == 2 :
316
- x_row = x [0 , : ]
315
+ elif np .ndim ( x ) == 2 :
316
+ x_row = x [0 ]
317
317
if not np .allclose (x_row , x ):
318
318
raise ValueError ("The rows of 'x' must be equal" )
319
319
x = x_row
320
320
else :
321
321
raise ValueError ("'x' can have at maximum 2 dimensions" )
322
322
323
- if y .ndim == 1 :
323
+ if np .ndim ( y ) == 1 :
324
324
pass
325
- elif y .ndim == 2 :
326
- y_col = y [:, 0 ]
327
- if not np .allclose (y_col , y .T ):
325
+ elif np .ndim (y ) == 2 :
326
+ yt = np .transpose (y ) # Also works for nested lists.
327
+ y_col = yt [0 ]
328
+ if not np .allclose (y_col , yt ):
328
329
raise ValueError ("The columns of 'y' must be equal" )
329
330
y = y_col
330
331
else :
0 commit comments