8000 vectorized calc_arrow loop in quiver (#15346) · matplotlib/matplotlib@327df3b · GitHub
[go: up one dir, main page]

Skip to content

Commit 327df3b

Browse files
authored
vectorized calc_arrow loop in quiver (#15346)
vectorized calc_arrow loop in quiver
2 parents 46c4585 + 01b0eb9 commit 327df3b

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2531,35 +2531,45 @@ def quiver(self, *args,
25312531
Any additional keyword arguments are delegated to
25322532
:class:`~matplotlib.collections.LineCollection`
25332533
"""
2534-
def calc_arrow(uvw, angle=15):
2535-
"""
2536-
To calculate the arrow head. uvw should be a unit vector.
2537-
We normalize it here:
2538-
"""
2539-
# get unit direction vector perpendicular to (u, v, w)
2540-
norm = np.linalg.norm(uvw[:2])
2541-
if norm > 0:
2542-
x = uvw[1] / norm
2543-
y = -uvw[0] / norm
2544-
else:
2545-
x, y = 0, 1
2534+
def calc_arrows(UVW, angle=15):
2535+
# get unit direction vector perpendicular to (u,v,w)
2536+
x = UVW[:, 0]
2537+
y = UVW[:, 1]
2538+
norm = np.linalg.norm(UVW[:, :2], axis=1)
2539+
x_p = np.divide(y, norm, where=norm != 0, out=np.zeros_like(x))
2540+
y_p = np.divide(-x, norm, where=norm != 0, out=np.ones_like(x))
25462541

25472542
# compute the two arrowhead direction unit vectors
25482543
ra = math.radians(angle)
25492544
c = math.cos(ra)
25502545
s = math.sin(ra)
25512546

25522547
# construct the rotation matrices
2553-
Rpos = np.array([[c+(x**2)*(1-c), x*y*(1-c), y*s],
2554-
[y*x*(1-c), c+(y**2)*(1-c), -x*s],
2555-
[-y*s, x*s, c]])
2548+
Rpos = np.array(
2549+
[[c + (x_p ** 2) * (1 - c), x_p * y_p * (1 - c), y_p * s],
2550+
[y_p * x_p * (1 - c), c + (y_p ** 2) * (1 - c), -x_p * s],
2551+
[-y_p * s, x_p * s, np.full_like(x_p, c)]])
2552+
Rpos = Rpos.transpose(2, 0, 1)
2553+
25562554
# opposite rotation negates all the sin terms
25572555
Rneg = Rpos.copy()
2558-
Rneg[[0, 1, 2, 2], [2, 2, 0, 1]] = \
2559-
-Rneg[[0, 1, 2, 2], [2, 2, 0, 1]]
2556+
Rneg[:, [0, 1, 2, 2], [2, 2, 0, 1]] = \
2557+
-Rneg[:, [0, 1, 2, 2], [2, 2, 0, 1]]
2558+
2559+
# expand dimensions for batched matrix multiplication
2560+
UVW = np.expand_dims(UVW, axis=-1)
25602561

25612562
# multiply them to get the rotated vector
2562-
return Rpos.dot(uvw), Rneg.dot(uvw)
2563+
Rpos_vecs = np.matmul(Rpos, UVW)
2564+
Rneg_vecs = np.matmul(Rneg, UVW)
2565+
2566+
# transpose for concatenation
2567+
Rpos_vecs = Rpos_vecs.transpose(0, 2, 1)
2568+
Rneg_vecs = Rneg_vecs.transpose(0, 2, 1)
2569+
2570+
head_dirs = np.concatenate([Rpos_vecs, Rneg_vecs], axis=1)
2571+
2572+
return head_dirs
25632573

25642574
had_data = self.has_data()
25652575

@@ -2621,7 +2631,7 @@ def calc_arrow(uvw, angle=15):
26212631
# compute the shaft lines all at once with an outer product
26222632
shafts = (XYZ - np.multiply.outer(shaft_dt, UVW)).swapaxes(0, 1)
26232633
# compute head direction vectors, n heads x 2 sides x 3 dimensions
2624-
head_dirs = np.array([calc_arrow(d) for d in UVW])
2634+
head_dirs = calc_arrows(UVW)
26252635
# compute all head lines at once, starting from the shaft ends
26262636
heads = shafts[:, :1] - np.multiply.outer(arrow_dt, head_dirs)
26272637
# stack left and right head lines together

0 commit comments

Comments
 (0)
0