12
12
import warnings
13
13
from math import pi
14
14
15
+ import matplotlib as mpl
15
16
import matplotlib .pyplot as plt
16
17
import numpy as np
17
18
import pytest
@@ -138,29 +139,39 @@ def invpend_ode(t, x, m=0, l=0, b=0, g=0):
138
139
139
140
# Use callable form, with parameters (if not correct, will get /0 error)
140
141
ct .phase_plane_plot (
141
- invpend_ode , [- 5 , 5 , - 2 , 2 ], params = {'args' : (1 , 1 , 0.2 , 1 )})
142
+ invpend_ode , [- 5 , 5 , - 2 , 2 ], params = {'args' : (1 , 1 , 0.2 , 1 )},
143
+ plot_streamlines = True )
142
144
143
145
# Linear I/O system
144
146
ct .phase_plane_plot (
145
- ct .ss ([[0 , 1 ], [- 1 , - 1 ]], [[0 ], [1 ]], [[1 , 0 ]], 0 ))
147
+ ct .ss ([[0 , 1 ], [- 1 , - 1 ]], [[0 ], [1 ]], [[1 , 0 ]], 0 ),
148
+ plot_streamlines = True )
146
149
147
150
148
151
@pytest .mark .usefixtures ('mplcleanup' )
149
152
def test_phaseplane_errors ():
150
153
with pytest .raises (ValueError , match = "invalid grid specification" ):
151
- ct .phase_plane_plot (ct .r
F438
ss (2 , 1 , 1 ), gridspec = 'bad' )
154
+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridspec = 'bad' ,
155
+ plot_streamlines = True )
152
156
153
157
with pytest .raises (ValueError , match = "unknown grid type" ):
154
- ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridtype = 'bad' )
158
+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridtype = 'bad' ,
159
+ plot_streamlines = True )
155
160
156
161
with pytest .raises (ValueError , match = "system must be planar" ):
157
- ct .phase_plane_plot (ct .rss (3 , 1 , 1 ))
162
+ ct .phase_plane_plot (ct .rss (3 , 1 , 1 ),
163
+ plot_streamlines = True )
158
164
159
165
with pytest .raises (ValueError , match = "params must be dict with key" ):
160
166
def invpend_ode (t , x , m = 0 , l = 0 , b = 0 , g = 0 ):
161
167
return (x [1 ], - b / m * x [1 ] + (g * l / m ) * np .sin (x [0 ]))
162
168
ct .phase_plane_plot (
163
- invpend_ode , [- 5 , 5 , 2 , 2 ], params = {'stuff' : (1 , 1 , 0.2 , 1 )})
169
+ invpend_ode , [- 5 , 5 , 2 , 2 ], params = {'stuff' : (1 , 1 , 0.2 , 1 )},
170
+ plot_streamlines = True )
171
+
172
+ with pytest .raises (ValueError , match = "gridtype must be 'meshgrid' when using streamplot" ):
173
+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), plot_streamlines = False ,
174
+ plot_streamplot = True , gridtype = 'boxgrid' )
164
175
165
176
# Warning messages for invalid solutions: nonlinear spring mass system
166
177
sys = ct .nlsys (
@@ -171,14 +182,87 @@ def invpend_ode(t, x, m=0, l=0, b=0, g=0):
171
182
UserWarning , match = r"initial_state=\[.*\], solve_ivp failed" ):
172
183
ct .phase_plane_plot (
173
184
sys , [- 12 , 12 , - 10 , 10 ], 15 , gridspec = [2 , 9 ],
174
- plot_separatrices = False )
185
+ plot_separatrices = False , plot_streamlines = True )
175
186
176
187
# Turn warnings off
177
188
with warnings .catch_warnings ():
178
189
warnings .simplefilter ("error" )
179
190
ct .phase_plane_plot (
180
191
sys , [- 12 , 12 , - 10 , 10 ], 15 , gridspec = [2 , 9 ],
181
- plot_separatrices = False , suppress_warnings = True )
192
+ plot_streamlines = True , plot_separatrices = False ,
193
+ suppress_warnings = True )
194
+
195
+ @pytest .mark .usefixtures ('mplcleanup' )
196
+ def test_phase_plot_zorder ():
197
+ # some of these tests are a bit akward since the streamlines and separatrices
198
+ # are stored in the same list, so we separate them by color
199
+ key_color = "tab:blue" # must not be 'k', 'r', 'b' since they are used by separatrices
200
+
201
+ def get_zorders (cplt ):
202
+ max_zorder = lambda items : max ([line .get_zorder () for line in items ])
203
+ assert isinstance (cplt .lines [0 ], list )
204
+ streamline_lines = [line for line in cplt .lines [0 ] if line .get_color () == key_color ]
205
+ separatrice_lines = [line for line in cplt .lines [0 ] if line .get_color () != key_color ]
206
+ streamlines = max_zorder (streamline_lines ) if streamline_lines else None
207
+ separatrices = max_zorder (separatrice_lines ) if separatrice_lines else None
208
+ assert cplt .lines [1 ] == None or isinstance (cplt .lines [1 ], mpl .quiver .Quiver )
209
+ quiver = cplt .lines [1 ].get_zorder () if cplt .lines [1 ] else None
210
+ assert cplt .lines [2 ] == None or isinstance (cplt .lines [2 ], list )
211
+ equilpoints = max_zorder (cplt .lines [2 ]) if cplt .lines [2 ] else None
212
+ assert cplt .lines [3 ] == None or isinstance (cplt .lines [3 ], mpl .streamplot .StreamplotSet )
213
+ streamplot = max (cplt .lines [3 ].lines .get_zorder (), cplt .lines [3 ].arrows .get_zorder ()) if cplt .lines [3 ] else None
214
+ return streamlines , quiver , streamplot , separatrices , equilpoints
215
+
216
+ def assert_orders (streamlines , quiver , streamplot , separatrices , equilpoints ):
217
+ print (streamlines , quiver , streamplot , separatrices , equilpoints )
218
+ if streamlines is not None :
219
+ assert streamlines < separatrices < equilpoints
220
+ if quiver is not None :
221
+ assert quiver < separatrices < equilpoints
222
+ if streamplot is not None :
223
+ assert streamplot < separatrices < equilpoints
224
+
225
+ def sys (t , x ):
226
+ return np .array ([4 * x [1 ], - np .sin (4 * x [0 ])])
227
+
228
+ # ensure correct zordering for all three flow types
229
+ res_streamlines = ct .phase_plane_plot (sys , plot_streamlines = dict (color = key_color ))
230
+ assert_orders (* get_zorders (res_streamlines ))
231
+ res_vectorfield = ct .phase_plane_plot (sys , plot_vectorfield = True )
232
+ assert_orders (* get_zorders (res_vectorfield ))
233
+ res_streamplot = ct .phase_plane_plot (sys , plot_streamplot =True )
234
+ assert_orders (* get_zorders (res_streamplot ))
235
+
236
+ # ensure that zorder can still be overwritten
237
+ res_reversed = ct .phase_plane_plot (sys , plot_streamlines = dict (color = key_color , zorder = 50 ), plot_vectorfield = dict (zorder = 40 ),
238
+ plot_streamplot = dict (zorder = 30 ), plot_separatrices = dict (zorder = 20 ), plot_equilpoints = dict (zorder = 10 ))
239
+ streamlines , quiver , streamplot , separatrices , equilpoints = get_zorders (res_reversed )
240
+ assert streamlines > quiver > streamplot > separatrices > equilpoints
241
+
242
+
243
+ @pytest .mark .usefixtures ('mplcleanup' )
244
+ def test_stream_plot_magnitude ():
245
+ def sys (t , x ):
246
+ return np .array ([4 * x [1 ], - np .sin (4 * x [0 ])])
247
+
248
+ # plt context with linewidth
249
+ with plt .rc_context ({'lines.linewidth' : 4 }):
250
+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_linewidth = True ))
251
+ linewidths = res .lines [3 ].lines .get_linewidths ()
252
+ # linewidths are scaled to be between 0.25 and 2 times default linewidth
253
+ # but the extremes may not exist if there is no line at that point
254
+ assert min (linewidths ) < 2 and max (linewidths ) > 7
255
+
256
+ # make sure changing the colormap works
257
+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , cmap = 'viridis' ))
258
+ assert res .lines [3 ].lines .get_cmap ().name == 'viridis'
259
+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , cmap = 'turbo' ))
260
+ assert res .lines [3 ].lines .get_cmap ().name == 'turbo'
261
+
262
+ # make sure changing the norm at least doesn't throw an error
263
+ ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , norm = mpl .colors .LogNorm ()))
264
+
265
+
182
266
183
267
184
268
@pytest .mark .usefixtures ('mplcleanup' )
@@ -190,7 +274,7 @@ def test_basic_phase_plots(savefigs=False):
190
274
plt .figure ()
191
275
axis_limits = [- 1 , 1 , - 1 , 1 ]
192
276
T = 8
193
- ct .phase_plane_plot (sys , axis_limits , T )
277
+ ct .phase_plane_plot (sys , axis_limits , T , plot_streamlines = True )
194
278
if savefigs :
195
279
plt .savefig ('phaseplot-dampedosc-default.png' )
196
280
@@ -203,7 +287,7 @@ def invpend_update(t, x, u, params):
203
287
ct .phase_plane_plot (
204
288
invpend , [- 2 * pi , 2 * pi , - 2 , 2 ], 5 ,
205
289
gridtype = 'meshgrid' , gridspec = [5 , 8 ], arrows = 3 ,
206
- plot_separatrices = {'gridspec' : [12 , 9 ]},
290
+ plot_separatrices = {'gridspec' : [12 , 9 ]}, plot_streamlines = True ,
207
291
params = {'m' : 1 , 'l' : 1 , 'b' : 0.2 , 'g' : 1 })
208
292
plt .xlabel (r"$\theta$ [rad]" )
209
293
plt .ylabel (r"$\dot\theta$ [rad/sec]" )
@@ -218,7 +302,8 @@ def oscillator_update(t, x, u, params):
218
302
oscillator_update , states = 2 , inputs = 0 , name = 'nonlinear oscillator' )
219
303
220
304
plt .figure ()
221
- ct .phase_plane_plot (oscillator , [- 1.5 , 1.5 , - 1.5 , 1.5 ], 0.9 )
305
+ ct .phase_plane_plot (oscillator , [- 1.5 , 1.5 , - 1.5 , 1.5 ], 0.9 ,
306
+ plot_streamlines = True )
222
307
pp .streamlines (
223
308
oscillator , np .array ([[0 , 0 ]]), 1.5 ,
224
309
gridtype = 'circlegrid' , gridspec = [0.5 , 6 ], dir = 'both' )
@@ -228,6 +313,18 @@ def oscillator_update(t, x, u, params):
228
313
if savefigs :
229
314
plt .savefig ('phaseplot-oscillator-helpers.png' )
230
315
316
+ plt .figure ()
317
+ ct .phase_plane_plot (
318
+ invpend , [- 2 * pi , 2 * pi , - 2 , 2 ],
319
+ plot_streamplot = dict (vary_color = True , vary_density = True ),
320
+ gridspec = [60 , 20 ], params = {'m' : 1 , 'l' : 1 , 'b' : 0.2 , 'g' : 1 }
321
+ )
322
+ plt .xlabel (r"$\theta$ [rad]" )
323
+ plt .ylabel (r"$\dot\theta$ [rad/sec]" )
324
+
325
+ if savefigs :
326
+ plt .savefig ('phaseplot-invpend-streamplot.png' )
327
+
231
328
232
329
if __name__ == "__main__" :
233
330
#
0 commit comments