8000 Address remaining quiver3d issues. Fix tests. · matplotlib/matplotlib@3112e67 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3112e67

Browse files
committed
Address remaining quiver3d issues. Fix tests.
1 parent 1229099 commit 3112e67

File tree

12 files changed

+33335
-16335
lines changed

12 files changed

+33335
-16335
lines changed

examples/mplot3d/quiver3d_demo.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
1313
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
14-
w = np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) * \
15-
np.sin(np.pi * z)
14+
w = (np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) *
15+
np.sin(np.pi * z))
1616

1717
ax.quiver(x, y, z, u, v, w, length=0.1)
1818

19-
plt.show()
19+
plt.show()
20+

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from . import art3d
3535
from . import proj3d
3636
from . import axis3d
37-
from mpl_toolkits.mplot3d.art3d import Line3DCollection
3837

3938
def unit_bbox():
4039
box = Bbox(np.array([[0, 0], [1, 1]]))
@@ -2431,19 +2430,23 @@ def quiver(self, *args, **kwargs):
24312430
*U*, *V*, *W*:
24322431
The direction vector that the arrow is pointing
24332432
2434-
The arguments could be iterable or scalars they will be broadcast together. The arguments can
2435-
also be masked arrays, if a position in any of argument is masked, then the corresponding
2436-
quiver will not be plotted.
2433+
The arguments could be array-like or scalars, so long as they
2434+
they can be broadcast together. The arguments can also be
2435+
masked arrays. If an element in any of argument is masked, then
2436+
that corresponding quiver element will not be plotted.
24372437
24382438
Keyword arguments:
24392439
24402440
*length*: [1.0 | float]
2441-
The length of each quiver, default to 1.0, the unit is the same with the axes
2441+
The length of each quiver, default to 1.0, the unit is
2442+
the same with the axes
24422443
24432444
*arrow_length_ratio*: [0.3 | float]
2444-
The ratio of the arrow head with respect to the quiver, default to 0.3
2445+
The ratio of the arrow head with respect to the quiver,
2446+
default to 0.3
24452447
2446-
Any additional keyword arguments are delegated to :class:`~matplotlib.collections.LineCollection`
2448+
Any additional keyword arguments are delegated to
2449+
:class:`~matplotlib.collections.LineCollection`
24472450
24482451
"""
24492452
def calc_arrow(u, v, w, angle=15):
@@ -2472,8 +2475,8 @@ def rotatefunction(angle):
24722475

24732476
# construct the rotation matrix
24742477
R = np.matrix([[c+(x**2)*(1-c), x*y*(1-c)-z*s, x*z*(1-c)+y*s],
2475-
[y*x*(1-c)+z*s, c+(y**2)*(1-c), y*z*(1-c)-x*s],
2476-
[z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+(z**2)*(1-c)]])
2478+
[y*x*(1-c)+z*s, c+(y**2)*(1-c), y*z*(1-c)-x*s],
2479+
[z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+(z**2)*(1-c)]])
24772480

24782481
# construct the column vector for (u,v,w)
24792482
line = np.matrix([[u],[v],[w]])
@@ -2512,7 +2515,9 @@ def point_vector_to_line(point, vector, length):
25122515
# first 6 arguments are X, Y, Z, U, V, W
25132516
input_args = args[:argi]
25142517
# if any of the args are scalar, convert into list
2515-
input_args = [[k] if isinstance(k, (int, float)) else k for k in input_args]
2518+
input_args = [[k] if isinstance(k, (int, float)) else k
2519+
for k in input_args]
2520+
25162521
# extract the masks, if any
25172522
masks = [k.mask for k in input_args if isinstance(k, np.ma.MaskedArray)]
25182523
# broadcast to match the shape
@@ -2523,36 +2528,43 @@ def point_vector_to_line(point, vector, length):
25232528
# combine the masks into one
25242529
mask = reduce(np.logical_or, masks)
25252530
# put mask on and compress
2526-
input_args = [np.ma.array(k, mask=mask).compressed() for k in input_args]
2531+
input_args = [np.ma.array(k, mask=mask).compressed()
2532+
for k in input_args]
25272533
else:
25282534
input_args = [k.flatten() for k in input_args]
25292535

2536+
if any(len(v) == 0 for v in input_args):
2537+
# No quivers, so just make an empty collection and return early
2538+
linec = art3d.Line3DCollection([], *args[6:], **kwargs)
2539+
self.add_collection(linec)
2540+
return linec
2541+
25302542
points = input_args[:3]
25312543
vectors = input_args[3:]
25322544

25332545
# Below assertions must be true before proceed
25342546
# must all be ndarray
2535-
assert all([isinstance(k, np.ndarray) for k in input_args])
2547+
assert all(isinstance(k, np.ndarray) for k in input_args)
25362548
# must all in same shape
25372549
assert len(set([k.shape for k in input_args])) == 1
25382550

2539-
25402551
# X, Y, Z, U, V, W
2541-
coords = list(map(lambda k: np.array(k) if not isinstance(k, np.ndarray) else k, args))
2552+
coords = (np.array(k) if not isinstance(k, np.ndarray) else k
2553+
for k in args)
25422554
coords = [k.flatten() for k in coords]
25432555
xs, ys, zs, us, vs, ws = coords
25442556
lines = []
25452557

25462558
# for each arrow
2547-
for i in xrange(xs.shape[0]):
2559+
for i in range(xs.shape[0]):
25482560
# calulate body
25492561
x = xs[i]
25502562
y = ys[i]
25512563
z = zs[i]
25522564
u = us[i]
25532565
v = vs[i]
25542566
w = ws[i]
2555-
if any([k is np.ma.masked for k in [x, y, z, u, v, w]]):
2567+
if any(k is np.ma.masked for k in [x, y, z, u, v, w]):
25562568
continue
25572569

25582570
# (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
@@ -2590,7 +2602,7 @@ def point_vector_to_line(point, vector, length):
25902602
line = list(zip(la2x, la2y, la2z))
25912603
lines.append(line)
25922604

2593-
linec = Line3DCollection(lines, *args[6:], **kwargs)
2605+
linec = art3d.Line3DCollection(lines, *args[6:], **kwargs)
25942606
self.add_collection(linec)
25952607

25962608
self.auto_scale_xyz(xs, ys, zs, had_data)
Binary file not shown.
Loading

0 commit comments

Comments
 (0)
0