8000 Merge pull request #595 from bnavigator/fix-594 · python-control/python-control@de87cc6 · GitHub
[go: up one dir, main page]

Skip to content

Commit de87cc6

Browse files
authored
Merge pull request #595 from bnavigator/fix-594
lti squeeze: ndarray.ndim == 0 is also a scalar
2 parents f1a9860 + 5646146 commit de87cc6

File tree

2 files changed

+40
-19
lines changed

2 files changed

+40
-19
lines changed

control/lti.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def _process_frequency_response(sys, omega, out, squeeze=None):
665665
if squeeze is None:
666666
squeeze = config.defaults['control.squeeze_frequency_response']
667667

668-
if not hasattr(omega, '__len__'):
668+
if np.asarray(omega).ndim < 1:
669669
# received a scalar x, squeeze down the array along last dim
670670
out = np.squeeze(out, axis=2)
671671

control/tests/lti_test.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
import control as ct
88
from control import c2d, tf, tf2ss, NonlinearIOSystem
9-
from control.lti import (LTI, common_timebase, damp, dcgain, isctime, isdtime,
10-
issiso, pole, timebaseEqual, zero)
9+
from control.lti import (LTI, common_timebase, evalfr, damp, dcgain, isctime,
10+
isdtime, issiso, pole, timebaseEqual, zero)
1111
from control.tests.conftest import slycotonly
1212
from control.exception import slycot_check
1313

@@ -179,11 +179,20 @@ def test_isdtime(self, objfun, arg, dt, ref, strictref):
179179
[1, 1, 2, [0.1, 1, 10], None, (1, 2, 3)], # MISO
180180
[2, 1, 2, [0.1, 1, 10], True, (2, 3)],
181181
[3, 1, 2, [0.1, 1, 10], False, (1, 2, 3)],
182+
[1, 1, 2, 0.1, None, (1, 2)],
183+
[1, 1, 2, 0.1, True, (2,)],
184+
[1, 1, 2, 0.1, False, (1, 2)],
182185
[1, 2, 2, [0.1, 1, 10], None, (2, 2, 3)], # MIMO
183186
[2, 2, 2, [0.1, 1, 10], True, (2, 2, 3)],
184-
[3, 2, 2, [0.1, 1, 10], False, (2, 2, 3)]
187+
[3, 2, 2, [0.1, 1, 10], False, (2, 2, 3)],
188+
[1, 2, 2, 0.1, None, (2, 2)],
189+
[2, 2, 2, 0.1, True, (2, 2)],
190+
[3, 2, 2, 0.1, False, (2, 2)],
185191
])
186-
def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):
192+
@pytest.mark.parametrize("omega_type", ["numpy", "native"])
193+
def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape,
194+
omega_type):
195+
"""Test correct behavior of frequencey response squeeze parameter."""
187196
# Create the system to be tested
188197
if fcn == ct.frd:
189198
sys = fcn(ct.rss(nstate, nout, ninp), [1e-2, 1e-1, 1, 1e1, 1e2])
@@ -193,15 +202,23 @@ def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):
193202
else:
194203
sys = fcn(ct.rss(nstate, nout, ninp))
195204

196-
# Convert the frequency list to an array for easy of use
197-
isscalar = not hasattr(omega, '__len__')
198-
omega = np.array(omega)
205+
if omega_type == "numpy":
206+
omega = np.asarray(omega)
207+
isscalar = omega.ndim == 0
208+
# keep the ndarray type even for scalars
209+
s = np.asarray(omega * 1j)
210+
else:
211+
isscalar = not hasattr(omega, '__len__')
212+
if isscalar:
213+
s = omega*1J
214+
else:
215+
s = [w*1J for w in omega]
199216

200217
# Call the transfer function directly and make sure shape is correct
201-
assert sys(omega * 1j, squeeze=squeeze).shape == shape
218+
assert sys(s, squeeze=squeeze).shape == shape
202219

203220
# Make sure that evalfr also works as expected
204-
assert ct.evalfr(sys, omega * 1j, squeeze=squeeze).shape == shape
221+
assert ct.evalfr(sys, s, squeeze=squeeze).shape == shape
205222

206223
# Check frequency response
207224
mag, phase, _ = sys.frequency_response(omega, squeeze=squeeze)
@@ -216,22 +233,22 @@ def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):
216233

217234
# Make sure the default shape lines up with squeeze=None case
218235
if squeeze is None:
219-
assert sys(omega * 1j).shape == shape
236+
assert sys(s).shape == shape
220237

221238
# Changing config.default to False should return 3D frequency response
222239
ct.config.set_defaults('control', squeeze_frequency_response=False)
223240
mag, phase, _ = sys.frequency_response(omega)
224241
if isscalar:
225242
assert mag.shape == (sys.noutputs, sys.ninputs, 1)
226243
assert phase.shape == (sys.noutputs, sys.ninputs, 1)
227-
assert sys(omega * 1j).shape == (sys.noutputs, sys.ninputs)
228-
assert ct.evalfr(sys, omega * 1j).shape == (sys.noutputs, sys.ninputs)
244+
assert sys(s).shape == (sys.noutputs, sys.ninputs)
245+
assert ct.evalfr(sys, s).shape == (sys.noutputs, sys.ninputs)
229246
else:
230247
assert mag.shape == (sys.noutputs, sys.ninputs, len(omega))
231248
assert phase.shape == (sys.noutputs, sys.ninputs, len(omega))
232-
assert sys(omega * 1j).shape == \
249+
assert sys(s).shape == \
233250
(sys.noutputs, sys.ninputs, len(omega))
234-
assert ct.evalfr(sys, omega * 1j).shape == \
251+
assert ct.evalfr(sys, s).shape == \
235252
(sys.noutputs, sys.ninputs, len(omega))
236253

237254
@pytest.mark.parametrize("fcn", [ct.ss, ct.tf, ct.frd, ct.ss2io])
@@ -243,13 +260,17 @@ def test_squeeze_exceptions(self, fcn):
243260

244261
with pytest.raises(ValueError, match="unknown squeeze value"):
245262
sys.frequency_response([1], squeeze=1)
246-
sys([1], squeeze='siso')
247-
evalfr(sys, [1], squeeze='siso')
263+
with pytest.raises(ValueError, match="unknown squeeze value"):
264+
sys([1j], squeeze='siso')
265+
with pytest.raises(ValueError, match="unknown squeeze value"):
266+
evalfr(sys, [1j], squeeze='siso')
248267

249268
with pytest.raises(ValueError, match="must be 1D"):
250269
sys.frequency_response([[0.1, 1], [1, 10]])
251-
sys([[0.1, 1], [1, 10]])
252-
evalfr(sys, [[0.1, 1], [1, 10]])
270+
with pytest.raises(ValueError, match="must be 1D"):
271+
sys([[0.1j, 1j], [1j, 10j]])
272+
with pytest.raises(ValueError, match="must be 1D"):
273+
evalfr(sys, [[0.1j, 1j], [1j, 10j]])
253274

254275
with pytest.warns(DeprecationWarning, match="LTI `inputs`"):
255276
ninputs = sys.inputs

0 commit comments

Comments
 (0)
0