@@ -49,16 +49,17 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
4949 Minimum length of streamline in axes coordinates.
5050
5151 Returns:
52-
52+
5353 *stream_container* : StreamplotSet
5454 Container object with attributes
55- lines : `matplotlib.collections.LineCollection` of streamlines
56- arrows : collection of `matplotlib.patches.FancyArrowPatch` objects
57- repesenting arrows half-way along stream lines.
58- This container will probably change in the future to allow changes to
59- the colormap, alpha, etc. for both lines and arrows, but these changes
60- should be backward compatible.
61-
55+ lines: `matplotlib.collections.LineCollection` of streamlines
56+ arrows: collection of `matplotlib.patches.FancyArrowPatch`
57+ objects representing arrows half-way along stream
58+ lines.
59+ This container will probably change in the future to allow changes
60+ to the colormap, alpha, etc. for both lines and arrows, but these
61+ changes should be backward compatible.
62+
6263 """
6364 grid = Grid (x , y )
6465 mask = StreamMask (density )
@@ -71,7 +72,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
7172 linewidth = matplotlib .rcParams ['lines.linewidth' ]
7273
7374 line_kw = {}
74- arrow_kw = dict (arrowstyle = arrowstyle , mutation_scale = 10 * arrowsize )
75+ arrow_kw = dict (arrowstyle = arrowstyle , mutation_scale = 10 * arrowsize )
7576
7677 use_multicolor_lines = isinstance (color , np .ndarray )
7778 if use_multicolor_lines :
@@ -104,7 +105,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
104105 if mask [ym , xm ] == 0 :
105106 xg , yg = dmap .mask2grid (xm , ym )
106107 t = integrate (xg , yg )
107- if t != None :
108+ if t is not None :
108109 trajectories .append (t )
109110
110111 if use_multicolor_lines :
<
10BC0
/code>@@ -128,10 +129,10 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
128129 streamlines .extend (np .hstack ([points [:- 1 ], points [1 :]]))
129130
130131 # Add arrows half way along each trajectory.
131- s = np .cumsum (np .sqrt (np .diff (tx )** 2 + np .diff (ty )** 2 ))
132+ s = np .cumsum (np .sqrt (np .diff (tx ) ** 2 + np .diff (ty ) ** 2 ))
132133 n = np .searchsorted (s , s [- 1 ] / 2. )
133134 arrow_tail = (tx [n ], ty [n ])
134- arrow_head = (np .mean (tx [n :n + 2 ]), np .mean (ty [n :n + 2 ]))
135+ arrow_head = (np .mean (tx [n :n + 2 ]), np .mean (ty [n :n + 2 ]))
135136
136137 if isinstance (linewidth , np .ndarray ):
137138 line_widths = interpgrid (linewidth , tgx , tgy )[:- 1 ]
@@ -143,15 +144,15 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
143144 line_colors .extend (color_values )
144145 arrow_kw ['color' ] = cmap (norm (color_values [n ]))
145146
146- p = patches .FancyArrowPatch (arrow_tail ,
147- arrow_head ,
148- transform = transform ,
147+ p = patches .FancyArrowPatch (arrow_tail ,
148+ arrow_head ,
149+ transform = transform ,
149150 ** arrow_kw )
150151 axes .add_patch (p )
151152 arrows .append (p )
152153
153- lc = mcollections .LineCollection (streamlines ,
154- transform = transform ,
154+ lc = mcollections .LineCollection (streamlines ,
155+ transform = transform ,
155156 ** line_kw )
156157 if use_multicolor_lines :
157158 lc .set_array (np .asarray (line_colors ))
@@ -275,7 +276,7 @@ def within_grid(self, xi, yi):
275276 """Return True if point is a valid index of grid."""
276277 # Note that xi/yi can be floats; so, for example, we can't simply check
277278 # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx`
278- return xi >= 0 and xi <= self .nx - 1 and yi >= 0 and yi <= self .ny - 1
279+ return xi >= 0 and xi <= self .nx - 1 and yi >= 0 and yi <= self .ny - 1
279280
280281
281282class StreamMask (object ):
@@ -330,6 +331,7 @@ def _update_trajectory(self, xm, ym):
330331class InvalidIndexError (Exception ):
331332 pass
332333
334+
333335class TerminateTrajectory (Exception ):
334336 pass
335337
@@ -345,7 +347,7 @@ def get_integrator(u, v, dmap, minlength):
345347 # speed (path length) will be in axes-coordinates
346348 u_ax = u / dmap .grid .nx
347349 v_ax = v / dmap .grid .ny
348- speed = np .ma .sqrt (u_ax ** 2 + v_ax ** 2 )
350+ speed = np .ma .sqrt (u_ax ** 2 + v_ax ** 2 )
349351
350352 def forward_time (xi , yi ):
351353 ds_dt = interpgrid (speed , xi , yi )
@@ -382,7 +384,7 @@ def integrate(x0, y0):
382384
383385 if stotal > minlength :
384386 return x_traj , y_traj
385- else : # reject short trajectories
387+ else : # reject short trajectories
386388 dmap .undo_trajectory ()
387389 return None
388390
@@ -423,7 +425,7 @@ def _integrate_rk12(x0, y0, dmap, f):
423425 ## increment the location gradually. However, due to the efficient
424426 ## nature of the interpolation, this doesn't boost speed by much
425427 ## for quite a bit of complexity.
426- maxds = min (1. / dmap .mask .nx , 1. / dmap .mask .ny , 0.1 )
428+ maxds = min (1. / dmap .mask .nx , 1. / dmap .mask .ny , 0.1 )
427429
428430 ds = maxds
429431 stotal = 0
@@ -455,7 +457,7 @@ def _integrate_rk12(x0, y0, dmap, f):
455457
456458 nx , ny = dmap .grid .shape
457459 # Error is normalized to the axes coordinates
458- error = np .sqrt (((dx2 - dx1 )/ nx )** 2 + ((dy2 - dy1 )/ ny )** 2 )
460+ error = np .sqrt (((dx2 - dx1 ) / nx ) ** 2 + ((dy2 - dy1 ) / ny ) ** 2 )
459461
460462 # Only save step if within error tolerance
461463 if error < maxerror :
@@ -473,7 +475,7 @@ def _integrate_rk12(x0, y0, dmap, f):
473475 if error == 0 :
474476 ds = maxds
475477 else :
476- ds = min (maxds , 0.85 * ds * (maxerror / error )** 0.5 )
478+ ds = min (maxds , 0.85 * ds * (maxerror / error ) ** 0.5 )
477479
478480 return stotal , xf_traj , yf_traj
479481
@@ -497,8 +499,8 @@ def _euler_step(xf_traj, yf_traj, dmap, f):
497499 else :
498500 dsy = (ny - 1 - yi ) / cy
499501 ds = min (dsx , dsy )
500- xf_traj .append (xi + cx * ds )
501- yf_traj .append (yi + cy * ds )
502+ xf_traj .append (xi + cx * ds )
503+ yf_traj .append (yi + cy * ds )
502504 return ds , xf_traj , yf_traj
503505
504506
@@ -519,10 +521,14 @@ def interpgrid(a, xi, yi):
519521 x = np .int (xi )
520522 y = np .int (yi )
521523 # conditional is faster than clipping for integers
522- if x == (Nx - 2 ): xn = x
523- else : xn = x + 1
524- if y == (Ny - 2 ): yn = y
525- else : yn = y + 1
524+ if x == (Nx - 2 ):
525+ xn = x
526+ else :
527+ xn = x + 1
528+ if y == (Ny - 2 ):
529+ yn = y
530+ else :
531+ yn = y + 1
526532
527533 a00 = a [y , x ]
528534 a01 = a [y , xn ]
@@ -563,20 +569,20 @@ def _gen_starting_points(shape):
563569 if direction == 'right' :
564570 x += 1
565571 if x >= xlast :
566- xlast -= 1
572+ xlast -= 1
567573 direction = 'up'
568574 elif direction == 'up' :
569575 y += 1
570576 if y >= ylast :
571- ylast -= 1
577+ ylast -= 1
572578 direction = 'left'
573579 elif direction == 'left' :
574580 x -= 1
575581 if x <= xfirst :
576- xfirst += 1
582+ xfirst += 1
577583 direction = 'down'
578584 elif direction == 'down' :
579585 y -= 1
580586 if y <= yfirst :
581- yfirst += 1
587+ yfirst += 1
582588 direction = 'right'
0 commit comments