8000 Merge pull request #829 from murrayrm/stochsys_ctime-29Dec2022 · python-control/python-control@a6e85c4 · GitHub
[go: up one dir, main page]

Skip to content

Commit a6e85c4

Browse files
authored
Merge pull request #829 from murrayrm/stochsys_ctime-29Dec2022
continuous time system support for create_estimator_iosystem
2 parents c3488cd + da1f162 commit a6e85c4

File tree

3 files changed

+130
-34
lines changed

3 files changed

+130
-34
lines changed

control/stochsys.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import scipy as sp
2121
from math import sqrt
2222

23-
from .iosys import InputOutputSystem, NonlinearIOSystem
23+
from .iosys import InputOutputSystem, LinearIOSystem, NonlinearIOSystem
2424
from .lti import LTI
2525
from .namedio import isctime, isdtime
2626
from .mateqn import care, dare, _check_shape
@@ -314,7 +314,13 @@ def create_estimator_iosystem(
314314
"""Create an I/O system implementing a linqear quadratic estimator
315315
316316
This function creates an input/output system that implements a
317-
state estimator of the form
317+
continuous time state estimator of the form
318+
319+
\dot xhat = A x + B u - L (C xhat - y)
320+
\dot P = A P + P A^T + F QN F^T - P C^T RN^{-1} C P
321+
L = P C^T RN^{-1}
322+
323+
or a discrete time state estimator of the form
318324
319325
8000 xhat[k + 1] = A x[k] + B u[k] - L (C xhat[k] - y[k])
320326
P[k + 1] = A P A^T + F QN F^T - A P C^T Reps^{-1} C P A
@@ -359,8 +365,9 @@ def create_estimator_iosystem(
359365
Returns
360366
-------
361367
estim : InputOutputSystem
362-
Input/output system representing the estimator. This system takes the
363-
system input and output and generates the estimated state.
368+
Input/output system representing the estimator. This system takes
369+
the system output y and input u and generates the estimated state
370+
xhat.
364371
365372
Notes
366373
-----
@@ -384,8 +391,8 @@ def create_estimator_iosystem(
384391
"""
385392

386393
# Make sure that we were passed an I/O system as an input
387-
if not isinstance(sys, InputOutputSystem):
388-
raise ControlArgument("Input system must be I/O system")
394+
if not isinstance(sys, LinearIOSystem):
395+
raise ControlArgument("Input system must be a linear I/O system")
389396

390397
# Extract the matrices that we need for easy reference
391398
A, B = sys.A, sys.B
@@ -409,7 +416,7 @@ def create_estimator_iosystem(
409416
# Initialize the covariance matrix
410417
if P0 is None:
411418
# Initalize P0 to the steady state value
412-
_, P0, _ = lqe(A, G, C, QN, RN)
419+
L0, P0, _ = lqe(A, G, C, QN, RN)
413420

414421
# Figure out the labels to use
415422
if isinstance(state_labels, str):
@@ -432,24 +439,54 @@ def create_estimator_iosystem(
432439
sensor_labels = [sensor_labels.format(i=i) for i in range(C.shape[0])]
433440

434441
if isctime(sys):
435-
raise NotImplementedError("continuous time not yet implemented")
436-
437-
else:
438442
# Create an I/O system for the state feedback gains
439443
# Note: reshape vector 9E88 s into column vectors for legacy np.matrix
444+
445+
R_inv = np.linalg.inv(RN)
446+
Reps_inv = C.T @ R_inv @ C
447+
448+
def _estim_update(t, x, u, params):
449+
# See if we are estimating or predicting
450+
correct = params.get('correct', True)
451+
452+
# Get the state of the estimator
453+
xhat = x[0:sys.nstates].reshape(-1, 1)
454+
P = x[sys.nstates:].reshape(sys.nstates, sys.nstates)
455+
456+
# Extract the inputs to the estimator
457+
y = u[0:C.shape[0]].reshape(-1, 1)
458+
u = u[C.shape[0]:].reshape(-1, 1)
459+
460+
# Compute the optimal gain
461+
L = P @ C.T @ R_inv
462+
463+
# Update the state estimate
464+
dxhat = A @ xhat + B @ u # prediction
465+
if correct:
466+
dxhat -= L @ (C @ xhat - y) # correction
467+
468+
# Update the covariance
469+
dP = A @ P + P @ A.T + G @ QN @ G.T
470+
if correct:
471+
dP -= P @ Reps_inv @ P
472+
473+
# Return the update
474+
return np.hstack([dxhat.reshape(-1), dP.reshape(-1)])
475+
476+
else:
440477
def _estim_update(t, x, u, params):
441478
# See if we are estimating or predicting
442479
correct = params.get('correct', True)
443480

444-
# Get the state of the estimator
481+
# Get the state of the estimator
445482
xhat = x[0:sys.nstates].reshape(-1, 1)
446483
P = x[sys.nstates:].reshape(sys.nstates, sys.nstates)
447484

448485
# Extract the inputs to the estimator
449486
y = u[0:C.shape[0]].reshape(-1, 1)
450487
u = u[C.shape[0]:].reshape(-1, 1)
451488

452-
# Compute the optimal again
489+
# Compute the optimal gain
453490
Reps_inv = np.linalg.inv(RN + C @ P @ C.T)
454491
L = A @ P @ C.T @ Reps_inv
455492

@@ -466,8 +503,8 @@ def _estim_update(t, x, u, params):
466503
# Return the update
467504
return np.hstack([dxhat.reshape(-1), dP.reshape(-1)])
468505

469-
def _estim_output(t, x, u, params):
470-
return x[0:sys.nstates]
506+
def _estim_output(t, x, u, params):
507+
return x[0:sys.nstates]
471508

472509
# Define the estimator system
473510
return NonlinearIOSystem(

control/tests/flatsys_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,17 @@ def test_kinematic_car_ocp(
212212
elif re.match("Iteration limit.*", traj_ocp.message) and \
213213
re.match(
214214
"conda ubuntu-3.* Generic", os.getenv('JOBNAME', '')) and \
215-
np.__version__ == '1.24.0':
215+
re.match("1.24.[01]", np.__version__):
216216
pytest.xfail("gh820: iteration limit exceeded")
217217

218218
else:
219219
# Dump out information to allow creation of an exception
220-
print("Platform: ", platform.platform())
221-
print("Python: ", platform.python_version())
220+
print("Message:", traj_ocp.message)
221+
print("Platform:", platform.platform())
222+
print("Python:", platform.python_version())
223+
print("NumPy version:", np.__version__)
222224
np.show_config()
223-
print("JOBNAME: ", os.getenv('JOBNAME'))
225+
print("JOBNAME:", os.getenv('JOBNAME'))
224226

225227
pytest.fail(
226228
"unknown failure; view output to identify configuration")

control/tests/stochsys_test.py

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import control as ct
99
from control import lqe, dlqe, rss, drss, tf, ss, ControlArgument, slycot_check
10+
from math import log, pi
1011

1112
# Utility function to check LQE answer
1213
def check_LQE(L, P, poles, G, QN, RN):
@@ -48,7 +49,7 @@ def test_lqe_call_format(cdlqe):
4849

4950
# Standard calling format
5051
Lref, Pref, Eref = cdlqe(sys.A, sys.B, sys.C, Q, R)
51-
52+
5253
# Call with system instead of matricees
5354
L, P, E = cdlqe(sys, Q, R)
5455
np.testing.assert_almost_equal(Lref, L)
@@ -58,15 +59,15 @@ def test_lqe_call_format(cdlqe):
5859
# Make sure we get an error if we specify N
5960
with pytest.raises(ct.ControlNotImplemented):
6061
L, P, E = cdlqe(sys, Q, R, N)
61-
62+
6263
# Inconsistent system dimensions
6364
with pytest.raises(ct.ControlDimension, match="Incompatible"):
6465
L, P, E = cdlqe(sys.A, sys.C, sys.B, Q, R)
65-
66+
6667
# Incorrect covariance matrix dimensions
6768
with pytest.raises(ct.ControlDimension, match="Incompatible"):
6869
L, P, E = cdlqe(sys.A, sys.B, sys.C, R, Q)
69-
70+
7071
# Too few input arguments
7172
with pytest.raises(ct.ControlArgument, match="not enough input"):
7273
L, P, E = cdlqe(sys.A, sys.C)
@@ -99,26 +100,26 @@ def test_lqe_discrete():
99100
np.testing.assert_almost_equal(K_csys, K_expl)
100101
np.testing.assert_almost_equal(S_csys, S_expl)
101102
np.testing.assert_almost_equal(E_csys, E_expl)
102-
103+
103104
# Calling lqe() with a discrete time system should call dlqe()
104105
K_lqe, S_lqe, E_lqe = ct.lqe(dsys, Q, R)
105106
K_dlqe, S_dlqe, E_dlqe = ct.dlqe(dsys, Q, R)
106107
np.testing.assert_almost_equal(K_lqe, K_dlqe)
107108
np.testing.assert_almost_equal(S_lqe, S_dlqe)
108109
np.testing.assert_almost_equal(E_lqe, E_dlqe)
109-
110+
110111
# Calling lqe() with no timebase should call lqe()
111112
asys = ct.ss(csys.A, csys.B, csys.C, csys.D, dt=None)
112113
K_asys, S_asys, E_asys = ct.lqe(asys, Q, R)
113114
K_expl, S_expl, E_expl = ct.lqe(csys.A, csys.B, csys.C, Q, R)
114115
np.testing.assert_almost_equal(K_asys, K_expl)
115116
np.testing.assert_almost_equal(S_asys, S_expl)
116117
np.testing.assert_almost_equal(E_asys, E_expl)
117-
118+
118119
# Calling dlqe() with a continuous time system should raise an error
119120
with pytest.raises(ControlArgument, match="called with a continuous"):
120121
K, S, E = ct.dlqe(csys, Q, R)
121-
122+
122123
def test_estimator_iosys():
123124
sys = ct.drss(4, 2, 2, strictly_proper=True)
124125

@@ -129,7 +130,7 @@ def test_estimator_iosys():
129130
QN = np.eye(sys.ninputs)
130131
RN = np.eye(sys.noutputs)
131132
estim = ct.create_estimator_iosystem(sys, QN, RN, P0)
132-
133+
133134
ctrl, clsys = ct.create_statefbk_iosystem(sys, K, estimator=estim)
134135

135136
# Extract the elements of the estimator
@@ -162,20 +163,76 @@ def test_estimator_iosys():
162163
np.testing.assert_almost_equal(cls.D, D_clchk)
163164

164165

166+
@pytest.mark.parametrize("sys_args", [
167+
([[-1]], [[1]], [[1]], 0), # scalar system
168+
([[-1, 0.1], [0, -2]], [[0], [1]], [[1, 0]], 0), # SISO, 2 state
169+
([[-1, 0.1], [0, -2]], [[1, 0], [0, 1]], [[1, 0]], 0), # 2i, 1o, 2s
170+
([[-1, 0.1, 0.1], [0, -2, 0], [0.1, 0, -3]], # 2i, 2o, 3s
171+
[[1, 0], [0, 0.1], [0, 1]],
172+
[[1, 0, 0.1], [0, 1, 0.1]], 0),
173+
])
174+
def test_estimator_iosys_ctime(sys_args):
175+
# Define the system we want to test
176+
sys = ct.ss(*sys_args)
177+
T = 10 * log(1e-2) / np.max(sys.poles().real)
178+
assert T > 0
179+
180+
# Create nonlinear version of the system to match integration methods
181+
nl_sys = ct.NonlinearIOSystem(
182+
lambda t, x, u, params : sys.A @ x + sys.B @ u,
183+
lambda t, x, u, params : sys.C @ x + sys.D @ u,
184+
inputs=sys.ninputs, outputs=sys.noutputs, states=sys.nstates)
185+
186+
# Define an initial condition, inputs (small, to avoid integration errors)
187+
timepts = np.linspace(0, T, 500)
188+
U = 2e-2 * np.array([np.sin(timepts + i*pi/3) for i in range(sys.ninputs)])
189+
X0 = np.ones(sys.nstates)
190+
191+
# Set up the parameters for the filter
192+
P0 = np.eye(sys.nstates)
193+
QN = np.eye(sys.ninputs)
194+
RN = np.eye(sys.noutputs)
195+
196+
# Construct the estimator
197+
estim = ct.create_estimator_iosystem(sys, QN, RN)
198+
199+
# Compute the system response and the optimal covariance
200+
sys_resp = ct.input_output_response(nl_sys, timepts, U, X0)
201+
_, Pf, _ = ct.lqe(sys, QN, RN)
202+
Pf = np.array(Pf) # convert from matrix, if needed
203+
204+
# Make sure that we converge to the optimal estimate
205+
estim_resp = ct.input_output_response(
206+
estim, timepts, [sys_resp.outputs, U], [0*X0, P0])
207+
np.testing.assert_allclose(
208+
estim_resp.states[0:sys.nstates, -1], sys_resp.states[:, -1],
209+
atol=1e-6, rtol=1e-3)
210+
np.testing.assert_allclose(
211+
estim_resp.states[sys.nstates:, -1], Pf.reshape(-1),
212+
atol=1e-6, rtol=1e-3)
213+
214+
# Make sure that optimal estimate is an eq pt
215+
ss_resp = ct.input_output_response(
216+
estim, timepts, [sys_resp.outputs, U], [X0, Pf])
217+
np.testing.assert_allclose(
218+
ss_resp.states[sys.nstates:],
219+
np.outer(Pf.reshape(-1), np.ones_like(timepts)),
220+
atol=1e-4, rtol=1e-2)
221+
np.testing.assert_allclose(
222+
ss_resp.states[0:sys.nstates], sys_resp.states,
223+
atol=1e-4, rtol=1e-2)
224+
225+
165226
def test_estimator_errors():
166227
sys = ct.drss(4, 2, 2, strictly_proper=True)
167228
P0 = np.eye(sys.nstates)
168229
QN = np.eye(sys.ninputs)
169230
RN = np.eye(sys.noutputs)
170231

171-
with pytest.raises(ct.ControlArgument, match="Input system must be I/O"):
232+
with pytest.raises(ct.ControlArgument, match=".* system must be a linear"):
172233
sys_tf = ct.tf([1], [1, 1], dt=True)
173234
estim = ct.create_estimator_iosystem(sys_tf, QN, RN)
174-
175-
with pytest.raises(NotImplementedError, match="continuous time not"):
176-
sys_ct = ct.rss(4, 2, 2, strictly_proper=True)
177-
estim = ct.create_estimator_iosystem(sys_ct, QN, RN)
178-
235+
179236
with pytest.raises(ValueError, match="output must be full state"):
180237
C = np.eye(2, 4)
181238
estim = ct.create_estimator_iosystem(sys, QN, RN, C=C)
@@ -246,7 +303,7 @@ def test_correlation():
246303
# Try passing a second argument
247304
tau, Rneg = ct.correlation(T, V, -V)
248305
np.testing.assert_equal(Rtau, -Rneg)
249-
306+
250307
# Test error conditions
251308
with pytest.raises(ValueError, match="Time vector T must be 1D"):
252309
tau, Rtau = ct.correlation(V, V)

0 commit comments

Comments
 (0)
0