@@ -2531,35 +2531,45 @@ def quiver(self, *args,
2531
2531
Any additional keyword arguments are delegated to
2532
2532
:class:`~matplotlib.collections.LineCollection`
2533
2533
"""
2534
- def calc_arrow (uvw , angle = 15 ):
2535
- """
2536
- To calculate the arrow head. uvw should be a unit vector.
2537
- We normalize it here:
2538
- """
2539
- # get unit direction vector perpendicular to (u, v, w)
2540
- norm = np .linalg .norm (uvw [:2 ])
2541
- if norm > 0 :
2542
- x = uvw [1 ] / norm
2543
- y = - uvw [0 ] / norm
2544
- else :
2545
- x , y = 0 , 1
2534
+ def calc_arrows (UVW , angle = 15 ):
2535
+ # get unit direction vector perpendicular to (u,v,w)
2536
+ x = UVW [:, 0 ]
2537
+ y = UVW [:, 1 ]
2538
+ norm = np .linalg .norm (UVW [:, :2 ], axis = 1 )
2539
+ x_p = np .divide (y , norm , where = norm != 0 , out = np .zeros_like (x ))
2540
+ y_p = np .divide (- x , norm , where = norm != 0 , out = np .ones_like (x ))
2546
2541
2547
2542
# compute the two arrowhead direction unit vectors
2548
2543
ra = math .radians (angle )
2549
2544
c = math .cos (ra )
2550
2545
s = math .sin (ra )
2551
2546
2552
2547
# construct the rotation matrices
2553
- Rpos = np .array ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c ), y * s ],
2554
- [y * x * (1 - c ), c + (y ** 2 )* (1 - c ), - x * s ],
2555
- [- y * s , x * s , c ]])
2548
+ Rpos = np .array (
2549
+ [[c + (x_p ** 2 ) * (1 - c ), x_p * y_p * (1 - c ), y_p * s ],
2550
+ [y_p * x_p * (1 - c ), c + (y_p ** 2 ) * (1 - c ), - x_p * s ],
2551
+ [- y_p * s , x_p * s , np .full_like (x_p , c )]])
2552
+ Rpos = Rpos .transpose (2 , 0 , 1 )
2553
+
2556
2554
# opposite rotation negates all the sin terms
2557
2555
Rneg = Rpos .copy ()
2558
- Rneg [[0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] = \
2559
- - Rneg [[0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]]
2556
+ Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] = \
2557
+ - Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]]
2558
+
2559
+ # expand dimensions for batched matrix multiplication
2560
+ UVW = np .expand_dims (UVW , axis = - 1 )
2560
2561
2561
2562
# multiply them to get the rotated vector
2562
- return Rpos .dot (uvw ), Rneg .dot (uvw )
2563
+ Rpos_vecs = np .matmul (Rpos , UVW )
2564
+ Rneg_vecs = np .matmul (Rneg , UVW )
2565
+
2566
+ # transpose for concatenation
2567
+ Rpos_vecs = Rpos_vecs .transpose (0 , 2 , 1 )
2568
+ Rneg_vecs = Rneg_vecs .transpose (0 , 2 , 1 )
2569
+
2570
+ head_dirs = np .concatenate ([Rpos_vecs , Rneg_vecs ], axis = 1 )
2571
+
2572
+ return head_dirs
2563
2573
2564
2574
had_data = self .has_data ()
2565
2575
@@ -2621,7 +2631,7 @@ def calc_arrow(uvw, angle=15):
2621
2631
# compute the shaft lines all at once with an outer product
2622
2632
shafts = (XYZ - np .multiply .outer (shaft_dt , UVW )).swapaxes (0 , 1 )
2623
2633
# compute head direction vectors, n heads x 2 sides x 3 dimensions
2624
- head_dirs = np . array ([ calc_arrow ( d ) for d in UVW ] )
2634
+ head_dirs = calc_arrows ( UVW )
2625
2635
# compute all head lines at once, starting from the shaft ends
2626
2636
heads = shafts [:, :1 ] - np .multiply .outer (arrow_dt , head_dirs )
2627
2637
# stack left and right head lines together
0 commit comments