10000 Merge pull request #794 from murrayrm/mutable_default_args-13Nov2022 · python-control/python-control@2a799e9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a799e9

Browse files
authored
Merge pull request #794 from murrayrm/mutable_default_args-13Nov2022
check for and fix mutable keyword defaults
2 parents 2dc409b + 4968cb3 commit 2a799e9

File tree

4 files changed

+78
-14
lines changed

control/flatsys/flatsys.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(self,
142142
forward, reverse, # flat system
143143
updfcn=None, outfcn=None, # I/O system
144144
inputs=None, outputs=None,
145-
states=None, params={}, dt=None, name=None):
145+
states=None, params=None, dt=None, name=None):
146146
"""Create a differentially flat I/O system.
147147
148148
The FlatIOSystem constructor is used to create an input/output system
@@ -171,7 +171,7 @@ def __str__(self):
171171
+ f"Forward: {self.forward}\n" \
172172
+ f"Reverse: {self.reverse}"
173173

174-
def forward(self, x, u, params={}):
174+
def forward(self, x, u, params=None):
175175

176176
"""Compute the flat flag given the states and input.
177177
@@ -200,7 +200,7 @@ def forward(self, x, u, params={}):
200200
"""
201201
raise NotImplementedError("internal error; forward method not defined")
202202

203-
def reverse(self, zflag, params={}):
203+
def reverse(self, zflag, params=None):
204204
"""Compute the states and input given the flat flag.
205205
206206
Parameters
@@ -224,18 +224,18 @@ def reverse(self, zflag, params={}):
224224
"""
225225
raise NotImplementedError("internal error; reverse method not defined")
226226

227-
def _flat_updfcn(self, t, x, u, params={}):
227+
def _flat_updfcn(self, t, x, u, params=None):
228228
# TODO: implement state space update using flat coordinates
229229
raise NotImplementedError("update function for flat system not given")
230230

231-
def _flat_outfcn(self, t, x, u, params={}):
231+
def _flat_outfcn(self, t, x, u, params=None):
232232
# Return the flat output
233233
zflag = self.forward(x, u, params)
234234
return np.array([zflag[i][0] for i in range(len(zflag))])
235235

236236

237237
# Utility function to compute flag matrix given a basis
238-
def _basis_flag_matrix(sys, basis, flag, t, params={}):
238+
def _basis_flag_matrix(sys, basis, flag, t):
239239
"""Compute the matrix of basis functions and their derivatives
240240
241241
This function computes the matrix ``M`` that is used to solve for the

control/flatsys/linflat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ def reverse(self, zflag, params):
142142
return np.reshape(x, self.nstates), np.reshape(u, self.ninputs)
143143

144144
# Update function
145-
def _rhs(self, t, x, u, params={}):
145+
def _rhs(self, t, x, u):
146146
# Use LinearIOSystem._rhs instead of default (MRO) NonlinearIOSystem
147147
return LinearIOSystem._rhs(self, t, x, u)
148148

149149
# output function
150-
def _out(self, t, x, u, params={}):
150+
def _out(self, t, x, u):
151151
# Use LinearIOSystem._out instead of default (MRO) NonlinearIOSystem
152152
return LinearIOSystem._out(self, t, x, u)

control/iosys.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,7 +1584,7 @@ def __init__(self, io_sys, ss_sys=None):
15841584
def input_output_response(
15851585
sys, T, U=0., X0=0, params=None,
15861586
transpose=False, return_x=False, squeeze=None,
1587-
solve_ivp_kwargs={}, t_eval='T', **kwargs):
1587+
solve_ivp_kwargs=None, t_eval='T', **kwargs):
15881588
"""Compute the output response of a system to a given input.
15891589
15901590
Simulate a dynamical system with a given input and return its output
@@ -1650,7 +1650,7 @@ def input_output_response(
16501650
solve_ivp_method : str, optional
16511651
Set the method used by :func:`scipy.integrate.solve_ivp`. Defaults
16521652
to 'RK45'.
1653-
solve_ivp_kwargs : str, optional
1653+
solve_ivp_kwargs : dict, optional
16541654
Pass additional keywords to :func:`scipy.integrate.solve_ivp`.
16551655
16561656
Raises
@@ -1676,6 +1676,7 @@ def input_output_response(
16761676
#
16771677

16781678
# Figure out the method to be used
1679+
solve_ivp_kwargs = solve_ivp_kwargs.copy() if solve_ivp_kwargs else {}
16791680
if kwargs.get('solve_ivp_method', None):
16801681
if kwargs.get('method', None):
16811682
raise ValueError("ivp_method specified more than once")

control/tests/kwargs_test.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def test_kwarg_search(module, prefix):
3838
# Skip anything that isn't part of the control package
3939
continue
4040

41+
# Look for classes and then check member functions
42+
if inspect.isclass(obj):
43+
test_kwarg_search(obj, prefix + obj.__name__ + '.')
44+
4145
# Only look for functions with keyword arguments
4246
if not inspect.isfunction(obj):
4347
continue
@@ -70,10 +74,6 @@ def test_kwarg_search(module, prefix):
7074
f"'unrecognized keyword' not found in unit test "
7175
f"for {name}")
7276

73-
# Look for classes and then check member functions
74-
if inspect.isclass(obj):
75-
test_kwarg_search(obj, prefix + obj.__name__ + '.')
76-
7777

7878
@pytest.mark.parametrize(
7979
"function, nsssys, ntfsys, moreargs, kwargs",
@@ -201,3 +201,66 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup):
201201
'TimeResponseData.__call__': trdata_test.test_response_copy,
202202
'TransferFunction.__init__': test_unrecognized_kwargs,
203203
}
204+
205+
#
206+
# Look for keywords with mutable defaults
207+
#
208+
# This test goes through every function and looks for signatures that have a
209+
# default value for a keyword that is mutable. An error is generated unless
210+
# the function is listed in the `mutable_ok` set (which should only be used
211+
# for cases were the code has been explicitly checked to make sure that the
212+
# value of the mutable is not modified in the code).
213+
#
214+
mutable_ok = { # initial and date
215+
control.flatsys.SystemTrajectory.__init__, # RMM, 18 Nov 2022
216+
control.freqplot._add_arrows_to_line2D, # RMM, 18 Nov 2022
217+
control.namedio._process_dt_keyword, # RMM, 13 Nov 2022
218+
control.namedio._process_namedio_keywords, # RMM, 18 Nov 2022
219+
control.optimal.OptimalControlProblem.__init__, # RMM, 18 Nov 2022
220+
control.optimal.solve_ocp, # RMM, 18 Nov 2022
221+
control.optimal.create_mpc_iosystem, # RMM, 18 Nov 2022
222+
}
223+
224+
@pytest.mark.parametrize("module", [control, control.flatsys])
225+
def test_mutable_defaults(module, recurse=True):
226+
# Look through every object in the package
227+
for name, obj in inspect.getmembers(module):
228+
# Skip anything that is outside of this module
229+
if inspect.getmodule(obj) is not None and \
230+
not inspect.getmodule(obj).__name__.startswith('control'):
231+
# Skip anything that isn't part of the control package
232+
continue
233+
234+
# Look for classes and then check member functions
235+
if inspect.isclass(obj):
236+
test_mutable_defaults(obj, True)
237+
238+
# Look for modules and check for internal functions (w/ no recursion)
239+
if inspect.ismodule(obj) and recurse:
240+
test_mutable_defaults(obj, False)
241+
242+
# Only look at functions and skip any that are marked as OK
243+
if not inspect.isfunction(obj) or obj in mutable_ok:
244+
continue
245+
246+
# Get the signature for the function
247+
sig = inspect.signature(obj)
248+
249+
# Skip anything that is inherited
250+
if inspect.isclass(module) and obj.__name__ not in module.__dict__:
251+
continue
252+
253+
# See if there is a variable keyword argument
254+
for argname, par in sig.parameters.items():
255+
if par.default is inspect._empty or \
256+
not par.kind == inspect.Parameter.KEYWORD_ONLY and \
257+
not par.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
258+
continue
259+
260+
# Check to see if the default value is mutable
261+
if par.default is not None and not \
262+
isinstance(par.default, (bool, int, float, tuple, str)):
263+
pytest.fail(
264+
f"function '{obj.__name__}' in module '{module.__name__}'"
265+
f" has mutable default for keyword '{par.name}'")
266+

0 commit comments

Comments
 (0)
0