8000 Merge pull request #1145 · python-control/python-control@632391c · GitHub
[go: up one dir, main page]

Skip to content

Commit 632391c

Browse files
authored
Merge pull request #1145
fix ax processing bug in {nyquist,nichols,describing_function}_plot
2 parents dc7d71b + 21f4912 commit 632391c

File tree

5 files changed

+34
-26
lines changed

5 files changed

+34
-26
lines changed

control/ctrlplot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _process_ax_keyword(
355355
the calling function to do the actual axis creation (needed for
356356
curvilinear grids that use the AxisArtist module).
357357
358-
Legacy behavior: some of the older plotting commands use a axes label
358+
Legacy behavior: some of the older plotting commands use an axes label
359359
to identify the proper axes for plotting. This behavior is supported
360360
through the use of the label keyword, but will only work if shape ==
361361
(1, 1) and squeeze == True.

control/descfcn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import math
1010
from warnings import warn
1111

12-
import matplotlib.pyplot as plt
1312
import numpy as np
1413
import scipy
1514

@@ -521,16 +520,17 @@ def describing_function_plot(
521520

522521
# Plot the Nyquist response
523522
cplt = dfresp.response.plot(**kwargs)
523+
ax = cplt.axes[0, 0] # Get the axes where the plot was made
524524
lines[0] = cplt.lines[0] # Return Nyquist lines for first system
525525

526526
# Add the describing function curve to the plot
527-
lines[1] = plt.plot(dfresp.N_vals.real, dfresp.N_vals.imag)
527+
lines[1] = ax.plot(dfresp.N_vals.real, dfresp.N_vals.imag)
528528

529529
# Label the intersection points
530530
if point_label:
531531
for pos, (a, omega) in zip(dfresp.positions, dfresp.intersections):
532532
# Add labels to the intersection points
533-
plt.text(pos.real, pos.imag, point_label % (a, omega))
533+
ax.text(pos.real, pos.imag, point_label % (a, omega))
534534

535535
return ControlPlot(lines, cplt.axes, cplt.figure)
536536

control/freqplot.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,7 +1913,7 @@ def _parse_linestyle(style_name, allow_false=False):
19131913
# Plot the regular portions of the curve (and grab the color)
19141914
x_reg = np.ma.masked_where(reg_mask, resp.real)
19151915
y_reg = np.ma.masked_where(reg_mask, resp.imag)
1916-
p = plt.plot(
1916+
p = ax.plot(
19171917
x_reg, y_reg, primary_style[0], color=color, label=label, **kwargs)
19181918
c = p[0].get_color()
19191919
out[idx] += p
@@ -1928,7 +1928,7 @@ def _parse_linestyle(style_name, allow_false=False):
19281928
x_scl = np.ma.masked_where(scale_mask, resp.real)
19291929
y_scl = np.ma.masked_where(scale_mask, resp.imag)
19301930
if x_scl.count() >= 1 and y_scl.count() >= 1:
1931-
out[idx] += plt.plot(
1931+
out[idx] += ax.plot(
19321932
x_scl * (1 + curve_offset),
19331933
y_scl * (1 + curve_offset),
19341934
primary_style[1], color=c, **kwargs)
@@ -1939,20 +1939,19 @@ def _parse_linestyle(style_name, allow_false=False):
19391939
x, y = resp.real.copy(), resp.imag.copy()
19401940
x[reg_mask] *= (1 + curve_offset[reg_mask])
19411941
y[reg_mask] *= (1 + curve_offset[reg_mask])
1942-
p = plt.plot(x, y, linestyle='None', color=c)
1942+
p = ax.plot(x, y, linestyle='None', color=c)
19431943

19441944
# Add arrows
1945-
ax = plt.gca()
19461945
_add_arrows_to_line2D(
19471946
ax, p[0], arrow_pos, arrowstyle=arrow_style, dir=1)
19481947

19491948
# Plot the mirror image
19501949
if mirror_style is not False:
19511950
# Plot the regular and scaled segments
1952-
out[idx] += plt.plot(
1951+
out[idx] += ax.plot(
19531952
x_reg, -y_reg, mirror_style[0], color=c, **kwargs)
19541953
if x_scl.count() >= 1 and y_scl.count() >= 1:
1955-
out[idx] += plt.plot(
1954+
out[idx] += ax.plot(
19561955
x_scl * (1 - curve_offset),
19571956
-y_scl * (1 - curve_offset),
19581957
mirror_style[1], color=c, **kwargs)
@@ -1963,19 +1962,19 @@ def _parse_linestyle(style_name, allow_false=False):
19631962
x, y = resp.real.copy(), resp.imag.copy()
19641963
x[reg_mask] *= (1 - curve_offset[reg_mask])
19651964
y[reg_mask] *= (1 - curve_offset[reg_mask])
1966-
p = plt.plot(x, -y, linestyle='None', color=c, **kwargs)
1965+
p = ax.plot(x, -y, linestyle='None', color=c, **kwargs)
19671966
_add_arrows_to_line2D(
19681967
ax, p[0], arrow_pos, arrowstyle=arrow_style, dir=-1)
19691968
else:
19701969
out[idx] += [None, None]
19711970

19721971
# Mark the start of the curve
19731972
if start_marker:
1974-
plt.plot(resp[0].real, resp[0].imag, start_marker,
1973+
ax.plot(resp[0].real, resp[0].imag, start_marker,
19751974
color=c, markersize=start_marker_size)
19761975

19771976
# Mark the -1 point
1978-
plt.plot([-1], [0], 'r+')
1977+
ax.plot([-1], [0], 'r+')
19791978

19801979
#
19811980
# Draw circles for gain crossover and sensitivity functions
@@ -1987,16 +1986,16 @@ def _parse_linestyle(style_name, allow_false=False):
19871986

19881987
# Display the unit circle, to read gain crossover frequency
19891988
if unit_circle:
1990-
plt.plot(cos, sin, **config.defaults['nyquist.circle_style'])
1989+
ax.plot(cos, sin, **config.defaults['nyquist.circle_style'])
19911990

19921991
# Draw circles for given magnitudes of sensitivity
19931992
if ms_circles is not None:
19941993
for ms in ms_circles:
19951994
pos_x = -1 + (1/ms)*cos
19961995
pos_y = (1/ms)*sin
1997-
plt.plot(
1996+
ax.plot(
19981997
pos_x, pos_y, **config.defaults['nyquist.circle_style'])
1999-
plt.text(pos_x[label_pos], pos_y[label_pos], ms)
1998+
ax.text(pos_x[label_pos], pos_y[label_pos], ms)
20001999

20012000
# Draw circles for given magnitudes of complementary sensitivity
20022001
if mt_circles is not None:
@@ -2006,17 +2005,17 @@ def _parse_linestyle(style_name, allow_false=False):
20062005
rt = mt/(mt**2-1) # Mt radius
20072006
pos_x = ct+rt*cos
20082007
pos_y = rt*sin
2009-
plt.plot(
2008+
ax.plot(
20102009
pos_x, pos_y,
20112010
**config.defaults['nyquist.circle_style'])
2012-
plt.text(pos_x[label_pos], pos_y[label_pos], mt)
2011+
ax.text(pos_x[label_pos], pos_y[label_pos], mt)
20132012
else:
2014-
_, _, ymin, ymax = plt.axis()
2013+
_, _, ymin, ymax = ax.axis()
20152014
pos_y = np.linspace(ymin, ymax, 100)
2016-
plt.vlines(
2015+
ax.vlines(
20172016
-0.5, ymin=ymin, ymax=ymax,
20182017
**config.defaults['nyquist.circle_style'])
2019-
plt.text(-0.5, pos_y[label_pos], 1)
2018+
ax.text(-0.5, pos_y[label_pos], 1)
20202019

20212020
# Label the frequencies of the points on the Nyquist curve
20222021
if label_freq:
@@ -2039,7 +2038,7 @@ def _parse_linestyle(style_name, allow_false=False):
20392038
# np.round() is used because 0.99... appears
20402039
# instead of 1.0, and this would otherwise be
20412040
# truncated to 0.
2042-
plt.text(xpt, ypt, ' ' +
2041+
ax.text(xpt, ypt, ' ' +
20432042
str(int(np.round(f / 1000 ** pow1000, 0))) + ' ' +
20442043
prefix + 'Hz')
20452044

control/nichols.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,15 @@ def nichols_plot(
132132
out[idx] = ax_nichols.plot(x, y, *fmt, label=label_, **kwargs)
133133

134134
# Label the plot axes
135-
plt.xlabel('Phase [deg]')
136-
plt.ylabel('Magnitude [dB]')
135+
ax_nichols.set_xlabel('Phase [deg]')
136+
ax_nichols.set_ylabel('Magnitude [dB]')
137137

138138
# Mark the -180 point
139-
plt.plot([-180], [0], 'r+')
139+
ax_nichols.plot([-180], [0], 'r+')
140140

141141
# Add grid
142142
if grid:
143-
nichols_grid()
143+
nichols_grid(ax=ax_nichols)
144144

145145
# List of systems that are included in this plot
146146
lines, labels = _get_line_labels(ax_nichols)

control/tests/ctrlplot_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,15 @@ def test_plot_ax_processing(resp_fcn, plot_fcn):
243243
# No response function available; just plot the data
244244
plot_fcn(*args, **kwargs, **plot_kwargs, ax=ax)
245245

246+
# Make sure the plot ended up in the right place
247+
assert len(axs[0, 0].get_lines()) == 0 # upper left
248+
assert len(axs[0, 1].get_lines()) != 0 # top middle
249+
assert len(axs[1, 0].get_lines()) == 0 # lower left
250+
if resp_fcn != ct.gangof4_response:
251+
assert len(axs[1, 2].get_lines()) == 0 # lower right (normally empty)
252+
else:
253+
assert len(axs[1, 2].get_lines()) != 0 # gangof4 uses this axes
254+
246255
# Check to make sure original settings did not change
247256
assert fig._suptitle.get_text() == title
248257
assert fig._suptitle.get_fontsize() == titlesize

0 commit comments

Comments
 (0)
0