8000 slight code refactoring to consolidate flag matrix computation · python-control/python-control@0c1d638 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0c1d638

Browse files
committed
slight code refactoring to consolidate flag matrix computation
1 parent 1fe6e86 commit 0c1d638

File tree

3 files changed

+69
-66
lines changed

3 files changed

+69
-66
lines changed

control/flatsys/flatsys.py

Lines changed: 53 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def __init__(self,
155155
if forward is not None: self.forward = forward
156156
if reverse is not None: self.reverse = reverse
157157

158+
# Save the length of the flat flag
159+
158160
def forward(self, x, u, params={}):
159161
"""Compute the flat flag given the states and input.
160162
@@ -217,10 +219,33 @@ def _flat_outfcn(self, t, x, u, params={}):
217219
return np.array(zflag[:][0])
218220

219221

222+
# Utility function to compute flag matrix given a basis
223+
def _basis_flag_matrix(sys, basis, flag, t, params={}):
224+
"""Compute the matrix of basis functions and their derivatives
225+
226+
This function computes the matrix ``M`` that is used to solve for the
227+
coefficients of the basis functions given the state and input. Each
228+
column of the matrix corresponds to a basis function and each row is a
229+
derivative, with the derivatives (flag) for each output stacked on top
230+
of each other.
231+
232+
"""
233+
flagshape = [len(f) for f in flag]
234+
M = np.zeros((sum(flagshape), basis.N * sys.ninputs))
235+
flag_off = 0
236+
coeff_off = 0
237+
for i, flag_len in enumerate(flagshape):
238+
for j, k in itertools.product(range(basis.N), range(flag_len)):
239+
M[flag_off + k, coeff_off + j] = basis.eval_deriv(j, k, t)
240+
flag_off += flag_len
241+
coeff_off += basis.N
242+
return M
243+
244+
220245
# Solve a point to point trajectory generation problem for a flat system
221246
def point_to_point(
222247
sys, timepts, x0=0, u0=0, xf=0, uf=0, T0=0, basis=None, cost=None,
223-
constraints=None, initial_guess=None, minimize_kwargs={}):
248+
constraints=None, initial_guess=None, minimize_kwargs={}, **kwargs):
224249
"""Compute trajectory between an initial and final conditions.
225250
226251
Compute a feasible trajectory for a differentially flat system between an
@@ -251,9 +276,9 @@ def point_to_point(
251276
252277
basis : :class:`~control.flatsys.BasisFamily` object, optional
253278
The basis functions to use for generating the trajectory. If not
254-
specified, the :class:`~control.flatsys.PolyFamily` basis family will be
255-
used, with the minimal number of elements required to find a feasible
256-
trajectory (twice the number of system states)
279+
specified, the :class:`~control.flatsys.PolyFamily` basis family
280+
will be used, with the minimal number of elements required to find a
281+
feasible trajectory (twice the number of system states)
257282
258283
cost : callable
259284
Function that returns the integral cost given the current state
@@ -287,6 +312,12 @@ def point_to_point(
287312
`eval()` function, we can be used to compute the value of the state
288313
and input and a given time t.
289314
315+
Notes
316+
-----
317+
Additional keyword parameters can be used to fine tune the behavior of
318+
the underlying optimization function. See `minimize_*` keywords in
319+
:func:`OptimalControlProblem` for more information.
320+
290321
"""
291322
#
292323
# Make sure the problem is one that we can handle
@@ -296,7 +327,7 @@ def point_to_point(
296327
u0 = _check_convert_array(u0, [(sys.ninputs,), (sys.ninputs, 1)],
297328
'Initial input: ', squeeze=True)
298329
xf = _check_convert_array(xf, [(sys.nstates,), (sys.nstates, 1)],
299-
'Final state: ' , squeeze=True)
330+
'Final state: ', squeeze=True)
300331
uf = _check_convert_array(uf, [(sys.ninputs,), (sys.ninputs, 1)],
301332
'Final input: ', squeeze=True)
302333

@@ -305,6 +336,12 @@ def point_to_point(
305336
Tf = timepts[-1]
306337
T0 = timepts[0] if len(timepts) > 1 else T0
307338

339+
# Process keyword arguments
340+
minimize_kwargs['method'] = kwargs.pop('minimize_method', None)
341+
minimize_kwargs['options'] = kwargs.pop('minimize_options', {})
342+
if kwargs:
343+
raise TypeError("unrecognized keywords: ", str(kwargs))
344+
308345
#
309346
# Determine the basis function set to use and make sure it is big enough
310347
#
@@ -328,8 +365,7 @@ def point_to_point(
328365
# We need to compute the output "flag": [z(t), z'(t), z''(t), ...]
329366
# and then evaluate this at the initial and final condition.
330367
#
331-
# TODO: should be able to represent flag variables as 1D arrays
332-
# TODO: need inputs to fully define the flag
368+
333369
zflag_T0 = sys.forward(x0, u0)
334370
zflag_Tf = sys.forward(xf, uf)
335371

@@ -340,41 +376,13 @@ def point_to_point(
340376
# essentially amounts to evaluating the basis functions and their
341377
# derivatives at the initial and final conditions.
342378

343-
# Figure out the size of the problem we are solving
344-
flag_tot = np.sum([len(zflag_T0[i]) for i in range(sys.ninputs)])
379+
# Compute the flags for the initial and final states
380+
M_T0 = _basis_flag_matrix(sys, basis, zflag_T0, T0)
381+
M_Tf = _basis_flag_matrix(sys, basis, zflag_Tf, Tf)
345382

346-
# Start by creating an empty matrix that we can fill up
347-
# TODO: allow a different number of basis elements for each flat output
348-
M = np.zeros((2 * flag_tot, basis.N * sys.ninputs))
349-
350-
# Now fill in the rows for the initial and final states
351-
# TODO: vectorize
352-
flag_off = 0
353-
coeff_off = 0
354-
355-
for i in range(sys.ninputs):
356-
flag_len = len(zflag_T0[i])
357-
for j in range(basis.N):
358-
for k in range(flag_len):
359-
M[flag_off + k, coeff_off + j] = basis.eval_deriv(j, k, T0)
360-
M[flag_tot + flag_off + k, coeff_off + j] = \
361-
basis.eval_deriv(j, k, Tf)
362-
flag_off += flag_len
363-
coeff_off += basis.N
364-
365-
# Create an empty matrix that we can fill up
366-
Z = np.zeros(2 * flag_tot)
367-
368-
# Compute the flag vector to use for the right hand side by
369-
# stacking up the flags for each input
370-
# TODO: make this more pythonic
371-
flag_off = 0
372-
for i in range(sys.ninputs):
373-
flag_len = len(zflag_T0[i])
374-
for j in range(flag_len):
375-
Z[flag_off + j] = zflag_T0[i][j]
376-
Z[flag_tot + flag_off + j] = zflag_Tf[i][j]
377-
flag_off += flag_len
383+
# Stack the initial and final matrix/flag for the point to point problem
384+
M = np.vstack([M_T0, M_Tf])
385+
Z = np.hstack([np.hstack(zflag_T0), np.hstack(zflag_Tf)])
378386

379387
#
380388
# Solve for the coefficients of the flat outputs
@@ -404,17 +412,7 @@ def traj_cost(null_coeffs):
404412
# Evaluate the costs at the listed time points
405413
costval = 0
406414
for t in timepts:
407-
M_t = np.zeros((flag_tot, basis.N * sys.ninputs))
408-
flag_off = 0
409-
coeff_off = 0
410-
for i in range(sys.ninputs):
411-
flag_len = len(zflag_T0[i])
412-
for j, k in itertools.product(
413-
range(basis.N), range(flag_len)):
414-
M_t[flag_off + k, coeff_off + j] = \
415-
basis.eval_deriv(j, k, t)
416-
flag_off += flag_len
417-
coeff_off += basis.N
415+
M_t = _basis_flag_matrix(sys, basis, zflag_T0, t)
418416

419417
# Compute flag at this time point
420418
zflag = (M_t @ coeffs).reshape(sys.ninputs, -1)
@@ -452,17 +450,7 @@ def traj_const(null_coeffs):
452450
values = []
453451
for i, t in enumerate(timepts):
454452
# Calculate the states and inputs for the flat output
455-
M_t = np.zeros((flag_tot, basis.N * sys.ninputs))
456-
flag_off = 0
457-
coeff_off = 0
458-
for i in range(sys.ninputs):
459-
flag_len = len(zflag_T0[i])
460-
for j, k in itertools.product(
461-
range(basis.N), range(flag_len)):
462-
M_t[flag_off + k, coeff_off + j] = \
463-
basis.eval_deriv(j, k, t)
464-
flag_off += flag_len
465-
coeff_off += basis.N
453+
M_t = _basis_flag_matrix(sys, basis, zflag_T0, t)
466454

467455
# Compute flag at this time point
468456
zflag = (M_t @ coeffs).reshape(sys.ninputs, -1)
@@ -501,7 +489,7 @@ def traj_const(null_coeffs):
501489

502490
# Process the initial condition
503491
if initial_guess is None:
504-
initial_guess = np.zeros(basis.N * sys.ninputs - 2 * flag_tot)
492+
initial_guess = np.zeros(M.shape[1] - M.shape[0])
505493
else:
506494
raise NotImplementedError("Initial guess not yet implemented.")
507495

@@ -514,7 +502,7 @@ def traj_const(null_coeffs):
514502
else:
515503
raise RuntimeError(
516504
"Unable to solve optimal control problem\n" +
517-
"scipy.optimize.minimize returned " + res.message)
505+
"scipy.optimize.minimize returned " + res.message)
518506

519507
#
520508
# Transform the trajectory from flat outputs to states and inputs

control/tests/flatsys_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,18 @@ def test_point_to_point_errors(self):
331331
traj = fs.point_to_point(
332332
flat_sys, timepts, x0, u0, xf, uf, constraints=constraint,
333333
basis=fs.PolyFamily(8))
334+
335+
# Method arguments, parameters
336+
traj_method = fs.point_to_point(
337+
flat_sys, timepts, x0, u0, xf, uf, cost=cost_fcn,
338+
basis=fs.PolyFamily(8), minimize_method='slsqp')
339+
traj_kwarg = fs.point_to_point(
340+
flat_sys, timepts, x0, u0, xf, uf, cost=cost_fcn,
341+
basis=fs.PolyFamily(8), minimize_kwargs={'method': 'slsqp'})
342+
np.testing.assert_almost_equal(
343+
traj_method.eval(timepts)[0], traj_kwarg.eval(timepts)[0])
344+
345+
# Unrecognized keywords
346+
with pytest.raises(TypeError, match="unrecognized keyword"):
347+
traj_method = fs.point_to_point(
348+
flat_sys, timepts, x0, u0, xf, uf, solve_ivp_method=None)

doc/flatsys.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
. _flatsys-module:
1+
.. _flatsys-module:
22

33
***************************
44
Differentially flat systems

0 commit comments

Comments
 (0)
0