From d6d77dcdb41c660d4a3b5db3114703018d433495 Mon Sep 17 00:00:00 2001
From: bnavigator <code@bnavigator.de>
Date: Fri, 21 Aug 2020 21:03:06 +0200
Subject: [PATCH] returnScipySignalLTI for discrete systems

and add unit tests
---
 control/statesp.py                   | 45 +++++++++++++++++++-----
 control/tests/statesp_matrix_test.py | 52 ++++++++++++++++++++++++++++
 control/tests/xferfcn_test.py        | 41 ++++++++++++++++++++++
 control/xferfcn.py                   | 39 ++++++++++++++++-----
 4 files changed, 161 insertions(+), 16 deletions(-)

diff --git a/control/statesp.py b/control/statesp.py
index 5af916bf0..d23fbd7be 100644
--- a/control/statesp.py
+++ b/control/statesp.py
@@ -59,7 +59,8 @@
 from numpy.linalg import solve, eigvals, matrix_rank
 from numpy.linalg.linalg import LinAlgError
 import scipy as sp
-from scipy.signal import lti, cont2discrete
+from scipy.signal import cont2discrete
+from scipy.signal import StateSpace as signalStateSpace
 from warnings import warn
 from .lti import LTI, timebase, timebaseEqual, isdtime
 from . import config
@@ -200,7 +201,7 @@ def __init__(self, *args, **kw):
             raise ValueError("Needs 1 or 4 arguments; received %i." % len(args))
 
         # Process keyword arguments
-        remove_useless = kw.get('remove_useless', 
+        remove_useless = kw.get('remove_useless',
                                 config.defaults['statesp.remove_useless_states'])
 
         # Convert all matrices to standard form
@@ -798,9 +799,7 @@ def minreal(self, tol=0.0):
         else:
             return StateSpace(self)
 
-
-    # TODO: add discrete time check
-    def returnScipySignalLTI(self):
+    def returnScipySignalLTI(self, strict=True):
         """Return a list of a list of :class:`scipy.signal.lti` objects.
 
         For instance,
@@ -809,15 +808,45 @@ def returnScipySignalLTI(self):
         >>> out[3][5]
 
         is a :class:`scipy.signal.lti` object corresponding to the transfer
-        function from the 6th input to the 4th output."""
+        function from the 6th input to the 4th output.
+
+        Parameters
+        ----------
+        strict : bool, optional
+            True (default):
+                The timebase `ssobject.dt` cannot be None; it must
+                be continuous (0) or discrete (True or > 0).
+            False:
+              If `ssobject.dt` is None, continuous time
+              :class:`scipy.signal.lti` objects are returned.
+
+        Returns
+        -------
+        out : list of list of :class:`scipy.signal.StateSpace`
+            continuous time (inheriting from :class:`scipy.signal.lti`)
+            or discrete time (inheriting from :class:`scipy.signal.dlti`)
+            SISO objects
+        """
+        if strict and self.dt is None:
+            raise ValueError("with strict=True, dt cannot be None")
+
+        if self.dt:
+            kwdt = {'dt': self.dt}
+        else:
+            # scipy convention for continuous time lti systems: call without
+            # dt keyword argument
+            kwdt = {}
 
         # Preallocate the output.
         out = [[[] for _ in range(self.inputs)] for _ in range(self.outputs)]
 
         for i in range(self.outputs):
             for j in range(self.inputs):
-                out[i][j] = lti(asarray(self.A), asarray(self.B[:, j]),
-                                asarray(self.C[i, :]), self.D[i, j])
+                out[i][j] = signalStateSpace(asarray(self.A),
+                                             asarray(self.B[:, j:j + 1]),
+                                             asarray(self.C[i:i + 1, :]),
+                                             asarray(self.D[i:i + 1, j:j + 1]),
+                                             **kwdt)
 
         return out
 
diff --git a/control/tests/statesp_matrix_test.py b/control/tests/statesp_matrix_test.py
index 34a17f992..e7e91364a 100644
--- a/control/tests/statesp_matrix_test.py
+++ b/control/tests/statesp_matrix_test.py
@@ -5,6 +5,7 @@
 
 import unittest
 import numpy as np
+import pytest
 from numpy.linalg import solve
 from scipy.linalg import eigvals, block_diag
 from control import matlab
@@ -673,5 +674,56 @@ def test_sample_system_prewarping(self):
             decimal=4)
 
 
+class TestLTIConverter:
+    """Test returnScipySignalLTI method"""
+
+    @pytest.fixture
+    def mimoss(self, request):
+        """Test system with various dt values"""
+        n = 5
+        m = 3
+        p = 2
+        bx, bu = np.mgrid[1:n + 1, 1:m + 1]
+        cy, cx = np.mgrid[1:p + 1, 1:n + 1]
+        dy, du = np.mgrid[1:p + 1, 1:m + 1]
+        return StateSpace(np.eye(5) + np.eye(5, 5, 1),
+                          bx * bu,
+                          cy * cx,
+                          dy * du,
+                          request.param)
+
+    @pytest.mark.parametrize("mimoss",
+                             [None,
+                              0,
+                              0.1,
+                              1,
+                              True],
+                             indirect=True)
+    def test_returnScipySignalLTI(self, mimoss):
+        """Test returnScipySignalLTI method with strict=False"""
+        sslti = mimoss.returnScipySignalLTI(strict=False)
+        for i in range(mimoss.outputs):
+            for j in range(mimoss.inputs):
+                np.testing.assert_allclose(sslti[i][j].A, mimoss.A)
+                np.testing.assert_allclose(sslti[i][j].B, mimoss.B[:,
+                                                                   j:j + 1])
+                np.testing.assert_allclose(sslti[i][j].C, mimoss.C[i:i + 1,
+                                                                   :])
+                np.testing.assert_allclose(sslti[i][j].D, mimoss.D[i:i + 1,
+                                                                   j:j + 1])
+                if mimoss.dt == 0:
+                    assert sslti[i][j].dt is None
+                else:
+                    assert sslti[i][j].dt == mimoss.dt
+
+    @pytest.mark.parametrize("mimoss", [None], indirect=True)
+    def test_returnScipySignalLTI_error(self, mimoss):
+        """Test returnScipySignalLTI method with dt=None and strict=True"""
+        with pytest.raises(ValueError):
+            mimoss.returnScipySignalLTI()
+        with pytest.raises(ValueError):
+            mimoss.returnScipySignalLTI(strict=True)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/control/tests/xferfcn_test.py b/control/tests/xferfcn_test.py
index 02e6c2b37..17e602090 100644
--- a/control/tests/xferfcn_test.py
+++ b/control/tests/xferfcn_test.py
@@ -4,6 +4,8 @@
 # RMM, 30 Mar 2011 (based on TestXferFcn from v0.4a)
 
 import unittest
+import pytest
+
 import sys as pysys
 import numpy as np
 from control.statesp import StateSpace, _convertToStateSpace, rss
@@ -934,5 +936,44 @@ def test_sample_system_prewarping(self):
             decimal=4)
 
 
+class TestLTIConverter:
+    """Test returnScipySignalLTI method"""
+
+    @pytest.fixture
+    def mimotf(self, request):
+        """Test system with various dt values"""
+        return TransferFunction([[[11], [12], [13]],
+                                 [[21], [22], [23]]],
+                                [[[1, -1]] * 3] * 2,
+                                request.param)
+
+    @pytest.mark.parametrize("mimotf",
+                             [None,
+                              0,
+                              0.1,
+                              1,
+                              True],
+                             indirect=True)
+    def test_returnScipySignalLTI(self, mimotf):
+        """Test returnScipySignalLTI method with strict=False"""
+        sslti = mimotf.returnScipySignalLTI(strict=False)
+        for i in range(2):
+            for j in range(3):
+                np.testing.assert_allclose(sslti[i][j].num, mimotf.num[i][j])
+                np.testing.assert_allclose(sslti[i][j].den, mimotf.den[i][j])
+                if mimotf.dt == 0:
+                    assert sslti[i][j].dt is None
+                else:
+                    assert sslti[i][j].dt == mimotf.dt
+
+    @pytest.mark.parametrize("mimotf", [None], indirect=True)
+    def test_returnScipySignalLTI_error(self, mimotf):
+        """Test returnScipySignalLTI method with dt=None and strict=True"""
+        with pytest.raises(ValueError):
+            mimotf.returnScipySignalLTI()
+        with pytest.raises(ValueError):
+            mimotf.returnScipySignalLTI(strict=True)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/control/xferfcn.py b/control/xferfcn.py
index 1cba50bd7..4077080e3 100644
--- a/control/xferfcn.py
+++ b/control/xferfcn.py
@@ -57,7 +57,8 @@
     polyadd, polymul, polyval, roots, sqrt, zeros, squeeze, exp, pi, \
     where, delete, real, poly, nonzero
 import scipy as sp
-from scipy.signal import lti, tf2zpk, zpk2tf, cont2discrete
+from scipy.signal import tf2zpk, zpk2tf, cont2discrete
+from scipy.signal import TransferFunction as signalTransferFunction
 from copy import deepcopy
 from warnings import warn
 from itertools import chain
@@ -801,7 +802,7 @@ def minreal(self, tol=None):
         # end result
         return TransferFunction(num, den, self.dt)
 
-    def returnScipySignalLTI(self):
+    def returnScipySignalLTI(self, strict=True):
         """Return a list of a list of :class:`scipy.signal.lti` objects.
 
         For instance,
@@ -809,22 +810,44 @@ def returnScipySignalLTI(self):
         >>> out = tfobject.returnScipySignalLTI()
         >>> out[3][5]
 
-        is a class:`scipy.signal.lti` object corresponding to the
+        is a :class:`scipy.signal.lti` object corresponding to the
         transfer function from the 6th input to the 4th output.
 
+        Parameters
+        ----------
+        strict : bool, optional
+            True (default):
+                The timebase `tfobject.dt` cannot be None; it must be
+                continuous (0) or discrete (True or > 0).
+            False:
+                if `tfobject.dt` is None, continuous time
+                :class:`scipy.signal.lti`objects are returned
+
+        Returns
+        -------
+        out : list of list of :class:`scipy.signal.TransferFunction`
+            continuous time (inheriting from :class:`scipy.signal.lti`)
+            or discrete time (inheriting from :class:`scipy.signal.dlti`)
+            SISO objects
         """
+        if strict and self.dt is None:
+            raise ValueError("with strict=True, dt cannot be None")
 
-        # TODO: implement for discrete time systems
-        if self.dt != 0 and self.dt is not None:
-            raise NotImplementedError("Function not \
-                    implemented in discrete time")
+        if self.dt:
+            kwdt = {'dt': self.dt}
+        else:
+            # scipy convention for continuous time lti systems: call without
+            # dt keyword argument
+            kwdt = {}
 
         # Preallocate the output.
         out = [[[] for j in range(self.inputs)] for i in range(self.outputs)]
 
         for i in range(self.outputs):
             for j in range(self.inputs):
-                out[i][j] = lti(self.num[i][j], self.den[i][j])
+                out[i][j] = signalTransferFunction(self.num[i][j],
+                                                   self.den[i][j],
+                                                   **kwdt)
 
         return out