8000 Refactor symbolic support · RPellowski/spatialmath-python@078ed0e · GitHub
[go: up one dir, main page]

Skip to content

Commit 078ed0e

Browse files
committed
Refactor symbolic support
1 parent b496c82 commit 078ed0e

File tree

4 files changed

+97
-76
lines changed

4 files changed

+97
-76
lines changed

spatialmath/base/symbolic.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import math
2+
3+
try: # pragma: no cover
4+
# print('Using SymPy')
5+
import sympy
6+
from sympy import S
7+
8+
_symbolics = True
9+
symtype = (sympy.Expr,)
10+
11+
except ImportError:
12+
_symbolics = False
13+
symtype = ()
14+
15+
16+
# ---------------------------------------------------------------------------------------#
17+
18+
def symbol(name, real=True):
19+
return sympy.symbols(name, real=real)
20+
21+
def issymbol(var):
22+
if _symbolics:
23+
if isinstance(var, (list, tuple)):
24+
return any([isinstance(x, symtype) for x in var])
25+
else:
26+
return isinstance(var, symtype)
27+
else:
28+
return False
29+
30+
def sin(theta):
31+
if issymbol(theta):
32+
return sympy.sin(theta)
33+
else:
34+
return math.sin(theta)
35+
36+
def cos(theta):
37+
if issymbol(theta):
38+
return sympy.cos(theta)
39+
else:
40+
return math.cos(theta)
41+
42+
def sqrt(v):
43+
if issymbol(v):
44+
return sympy.sqrt(v)
45+
else:
46+
return math.sqrt(v)
47+
48+
def zero():
49+
return S.Zero
50+
51+
def one():
52+
return S.One
53+
54+
def negative_one():
55+
return S.NegativeOne
56+
57+
def zero():
58+
return S.Zero
59+
60+
def pi():
61+
return S.Pi

spatialmath/base/transforms2d.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -48,38 +48,11 @@
4848
from spatialmath.base import vectors as vec
4949
from spatialmath.base import transformsNd as trn
5050
from spatialmath.base import animate
51-
52-
53-
try: # pragma: no cover
54-
#print('Using SymPy')
55-
import sympy as sym
56-
57-
def _issymbol(x):
58-
return isinstance(x, sym.Symbol)
59-
except ImportError:
60-
def _issymbol(x): # pylint: disable=unused-argument
61-
return False
51+
import spatialmath.base.symbolic as sym
6252

6353
_eps = np.finfo(np.float64).eps
6454

6555

66-
# ---------------------------------------------------------------------------------------#
67-
68-
69-
def _cos(theta):
70-
if _issymbol(theta):
71-
return sym.cos(theta)
72-
else:
73-
return math.cos(theta)
74-
75-
76-
def _sin(theta):
77-
if _issymbol(theta):
78-
return sym.sin(theta)
79-
else:
80-
return math.sin(theta)
81-
82-
8356
# ---------------------------------------------------------------------------------------#
8457
def rot2(theta, unit='rad'):
8558
"""
@@ -96,13 +69,11 @@ def rot2(theta, unit='rad'):
9669
- ``ROT2(THETA, 'deg')`` as above but THETA is in degrees.
9770
"""
9871
theta = argcheck.getunit(theta, unit)
99-
ct = _cos(theta)
100-
st = _sin(theta)
72+
ct = sym.cos(theta)
73+
st = sym.sin(theta)
10174
R = np.array([
10275
[ct, -st],
10376
[st, ct]])
104-
if not _issymbol(theta):
105-
R = R.round(15)
10677
return R
10778

10879

spatialmath/base/transforms3d.py

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -53,36 +53,13 @@
5353
from spatialmath.base import transformsNd as trn
5454
from spatialmath.base import quaternions as quat
5555
from spatialmath.base import animate
56-
57-
try: # pragma: no cover
58-
# print('Using SymPy')
59-
import sympy as sym
60-
61-
def _issymbol(x):
62-
return isinstance(x, sym.Symbol)
63-
except ImportError:
64-
def _issymbol(x): # pylint: disable=unused-argument
65-
return False
56+
import spatialmath.base.symbolic as sym
6657

6758
_eps = np.finfo(np.float64).eps
6859

6960
# ---------------------------------------------------------------------------------------#
7061

7162

72-
def _cos(theta):
73-
if _issymbol(theta):
74-
return sym.cos(theta)
75-
else:
76-
return math.cos(theta)
77-
78-
79-
def _sin(theta):
80-
if _issymbol(theta):
81-
return sym.sin(theta)
82-
else:
83-
return math.sin(theta)
84-
85-
8663
def rotx(theta, unit="rad"):
8764
"""
8865
Create SO(3) rotation about X-axis
@@ -102,8 +79,8 @@ def rotx(theta, unit="rad"):
10279
"""
10380

10481
theta = argcheck.getunit(theta, unit)
105-
ct = _cos(theta)
106-
st = _sin(theta)
82+
ct = sym.cos(theta)
83+
st = sym.sin(theta)
10784
R = np.array([
10885
[1, 0, 0],
10986
[0, ct, -st],
@@ -131,8 +108,8 @@ def roty(theta, unit="rad"):
131108
"""
132109

133110
theta = argcheck.getunit(theta, unit)
134-
ct = _cos(theta)
135-
st = _sin(theta)
111+
ct = sym.cos(theta)
112+
st = sym.sin(theta)
136113
R = np.array([
137114
[ct, 0, st],
138115
[0, 1, 0],
@@ -159,8 +136,8 @@ def rotz(theta, unit="rad"):
159136
:seealso: :func:`~yrotz`
160137
"""
161138
theta = argcheck.getunit(theta, unit)
162-
ct = _cos(theta)
163-
st = _sin(theta)
139+
ct = sym.cos(theta)
140+
st = sym.sin(theta)
164141
R = np.array([
165142
[ct, -st, 0],
166143
[st, ct, 0],
@@ -1315,7 +1292,7 @@ def tr2jac(T, samebody=False):
13151292
return np.block([[R.T, Z], [Z, R.T]])
13161293

13171294

1318-
def trprint(T, orient='rpy/zyx', label=None, file=sys.stdout, fmt='{:8.2g}', unit='deg'):
1295+
def trprint(T, orient='rpy/zyx', label=None, file=sys.stdout, fmt='{:8.2g}', degsym=True, unit='deg'):
13191296
"""
13201297
Compact display of SO(3) or SE(3) matrices
13211298
@@ -1394,19 +1371,26 @@ def trprint(T, orient='rpy/zyx', label=None, file=sys.stdout, fmt='{:8.2g}', uni
13941371
else:
13951372
seq = None
13961373
angles = tr2rpy(T, order=seq, unit=unit)
1397-
s += ' {} = {} {}'.format(orient, _vec2s(fmt, angles), unit)
1374+
if degsym and unit == "deg":
1375+
fmt += "\u00b0"
1376+
s += ' {} = {}'.format(orient, _vec2s(fmt, angles))
13981377

13991378
elif a[0].startswith('eul'):
14001379
angles = tr2eul(T, unit)
1401-
s += ' eul = {} {}'.format(_vec2s(fmt, angles), unit)
1380+
if degsym and unit == "deg":
1381+
fmt += "\u00b0"
1382+
s += ' eul = {}'.format(_vec2s(fmt, angles))
14021383

14031384
elif a[0] == 'angvec':
14041385
# as a vector and angle
14051386
(theta, v) = tr2angvec(T, unit)
14061387
if theta == 0:
14071388
s += ' R = nil'
14081389
else:
1409-
s += ' angvec = ({} {} | {})'.format(fmt.format(theta), unit, _vec2s(fmt, v))
1390+
theta = fmt.format(theta)
1391+
if degsym and unit == "deg":
1392+
theta += "\u00b0"
1393+
s += ' angvec = ({} | {})'.format(theta, _vec2s(fmt, v))
14101394
else:
14111395
raise ValueError('bad orientation format')
14121396

spatialmath/base/transformsNd.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from spatialmath.base import transforms2d as t2d
2020
from spatialmath.base import transforms3d as t3d
2121
from spatialmath.base import argcheck
22+
import spatialmath.base.symbolic as sym
2223

2324
_eps = np.finfo(np.float64).eps
2425

@@ -41,21 +42,25 @@ def r2t(R, check=False):
4142
4243
:seealso: t2r, rt2tr
4344
"""
44-
45-
assert isinstance(R, np.ndarray)
4645
dim = R.shape
4746
assert dim[0] == dim[1], 'Matrix must be square'
48-
49-
if check and np.abs(np.linalg.det(R) - 1) < 100 * _eps:
50-
raise ValueError('Invalid rotation matrix ')
51-
52-
# T = np.pad(R, (0, 1), mode='constant')
53-
# T[-1, -1] = 1.0
5447
n = dim[0] + 1
5548
m = dim[0]
56-
T = np.zeros((n, n))
49+
50+
if R.dtype == 'O':
51+
# symbolic matrix
52+
T = np.zeros((n, n), dtype='O')
53+
else:
54+
# numeric matrix
55+
assert isinstance(R, np.ndarray)
56+
if check and np.abs(np.linalg.det(R) - 1) < 100 * _eps:
57+
raise ValueError('Invalid rotation matrix ')
58+
59+
# T = np.pad(R, (0, 1), mode='constant')
60+
# T[-1, -1] = 1.0
61+
T = np.zeros((n, n))
5762
T[:m,:m] = R
58-
T[-1, -1] = 1.0
63+
T[-1, -1] = 1
5964

6065
return T
6166

0 commit comments

Comments
 (0)
0