8000 continue refactoring · Robertleoj/spatialmath-python@bd8b82d · GitHub
[go: up one dir, main page]

Skip to content

Commit bd8b82d

Browse files
committed
continue refactoring
1 parent b959426 commit bd8b82d

18 files changed

+63
-260
lines changed

spatialmath/DualQuaternion.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from spatialmath import Quaternion, UnitQuaternion
44
from spatialmath import base
5+
import spatialmath.pose3d as pose3d
56
from spatialmath.base.types import ArrayLike3, R8x8, R8
67
from typing import Self, overload
78

@@ -31,7 +32,9 @@ class DualQuaternion:
3132
:seealso: :func:`UnitDualQuaternion`
3233
"""
3334

34-
def __init__(self, real: Quaternion = None, dual: Quaternion = None):
35+
def __init__(
36+
self, real: Quaternion | None = None, dual: Quaternion | None = None
37+
) -> None:
3538
"""
3639
Construct a new dual quaternion
3740
@@ -268,7 +271,7 @@ class UnitDualQuaternion(DualQuaternion):
268271
"""
269272

270273
@overload
271-
def __init__(self, T: SE3): ...
274+
def __init__(self, T: pose3d.SE3): ...
272275

273276
@overload
274277
def __init__(self, real: Quaternion, dual: Quaternion): ...
@@ -313,7 +316,7 @@ def __init__(self, real=None, dual=None):
313316
:math:`t`.
314317
"""
315318

316-
if dual is None and isinstance(real, SE3):
319+
if dual is None and isinstance(real, pose3d.SE3):
317320
T = real
318321
S = UnitQuaternion(T.R)
319322
D = Quaternion.Pure(T.t)
@@ -323,7 +326,7 @@ def __init__(self, real=None, dual=None):
323326

324327
super().__init__(real, dual)
325328

326-
def SE3(self) -> SE3:
329+
def SE3(self) -> pose3d.SE3:
327330
"""
328331
Convert unit dual quaternion to SE(3) matrix
329332
@@ -344,18 +347,8 @@ def SE3(self) -> SE3:
344347
R = base.q2r(self.real.A)
345348
t = 2 * self.dual * self.real.conj()
346349

347-
return SE3(base.rt2tr(R, t.v))
350+
return pose3d.SE3(base.rt2tr(R, t.v))
348351

349-
# def exp(self):
350-
# w = self.real.v
351-
# v = self.dual.v
352-
# theta = base.norm(w)
353352

354-
355-
if __name__ == "__main__": # pragma: no cover
356-
from spatialmath import SE3, UnitDualQuaternion
357-
358-
print(UnitDualQuaternion(SE3()))
359-
# import pathlib
360-
361-
# exec(open(pathlib.Path(__file__).parent.parent.absolute() / "tests" / "test_dualquaternion.py").read()) # pylint: disable=exec-used
353+
if __name__ == "__main__":
354+
print(UnitDualQuaternion(pose3d.SE3()))

spatialmath/base/__init__.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# from spatialmath.base.animate import * # lgtm [py/polluting-import]
1313
# from spatialmath.base.graphics import * # lgtm [py/polluting-import]
1414
# from spatialmath.base.numeric import * # lgtm [py/polluting-import]
15+
import spatialmath.base.symbolic as symbolic
1516

1617

1718
from spatialmath.base.graphics import (
@@ -56,6 +57,7 @@
5657
from spatialmath.base.quaternions import (
5758
qpure,
5859
qnorm,
60+
qeye,
5961
qunit,
6062
qisunit,
6163
qisequal,
@@ -135,6 +137,7 @@
135137
trprint,
136138
trplot,
137139
tranimate,
140+
tr2adjoint,
138141
)
139142
from spatialmath.base.transformsNd import (
140143
t2r,
@@ -167,6 +170,7 @@
167170
unittwist,
168171
unittwist_norm,
169172
unittwist2,
173+
unittwist2_norm,
170174
angdiff,
171175
removesmall,
172176
cross,
@@ -175,7 +179,13 @@
175179
wrap_mpi_pi,
176180
)
177181

178-
# from spatialmath.base.symbolic import *
182+
from spatialmath.base.symbolic import (
183+
sqrt,
184+
sin,
185+
cos,
186+
tan,
187+
)
188+
179189
# from spatialmath.base.animate import Animate, Animate2
180190
# from spatialmath.base.graphics import (
181191
# plotvol2,
@@ -200,7 +210,7 @@
200210
# axes_logic,
201211
# isnotebook,
202212
# )
203-
from spatialmath.base.numeric import numjac, array2str, bresenham, numhess
213+
from spatialmath.base.numeric import numjac, array2str, bresenham, numhess, mpq_point
204214

205215

206216
__all__ = [
@@ -219,6 +229,7 @@
219229
# spatialmath.base.quaternions
220230
"qpure",
221231
"qnorm",
232+
"qeye",
222233
"qunit",
223234
"qisunit",
224235
"qisequal",
@@ -335,13 +346,19 @@
335346
"unittwist",
336347
"unittwist_norm",
337348
"unittwist2",
349+
"unittwist2_norm",
338350
"angdiff",
339351
"removesmall",
340352
"cross",
341353
"iszero",
342354
"wrap_0_2pi",
343355
"wrap_mpi_pi",
344356
"wrap_0_pi",
357+
# spatialmath.base.symbolic
358+
"sqrt",
359+
"sin",
360+
"cos",
361+
"tan",
345362
# spatialmath.base.animate
346363
"Animate",
347364
"Animate2",
@@ -378,4 +395,6 @@
378395
"mpq_point",
379396
"gauss1d",
380397
"gauss2d",
398+
# modules
399+
"symbolic",
381400
]

spatialmath/base/transforms2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array:
5252
>>> rot2(45, 'deg')
5353
"""
5454
theta = smb.getunit(theta, unit, dim=0)
55-
ct = smb.sym.cos(theta)
56-
st = smb.sym.sin(theta)
55+
ct = smb.symbolic.cos(theta)
56+
st = smb.symbolic.sin(theta)
5757
# fmt: off
5858
R = np.array([
5959
[ct, -st],

spatialmath/base/transforms3d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,9 +740,10 @@ def angvec2r(theta: float, v: ArrayLike3, unit="rad", tol: float = 20) -> SO3Arr
740740
return np.eye(3)
741741

742742
θ = getunit(theta, unit)
743+
if isinstance(θ, np.ndarray):
744+
θ = float(θ.squeeze())
743745

744746
# Rodrigue's equation
745-
746747
sk = skew(cast(ArrayLike3, unitvec(v)))
747748
R = np.eye(3) + math.sin(θ) * sk + (1.0 - math.cos(θ)) * sk @ sk
748749
return R

spatialmath/base/transformsNd.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
ArrayLike3,
1414
se2Array,
1515
se3Array,
16-
so2Array,
17-
so3Array,
1816
senArray,
19-
R1,
2017
R6,
2118
ArrayLike6,
2219
)
@@ -518,11 +515,11 @@ def skew(v):
518515

519516
# ---------------------------------------------------------------------------------------#
520517
@overload
521-
def vex(s: so2Array, check: bool = False) -> R1: ...
518+
def vex(s: se2Array, check: bool = False) -> R2: ...
522519

523520

524521
@overload
525-
def vex(s: so3Array, check: bool = False) -> R3: ...
522+
def vex(s: se3Array, check: bool = False) -> R3: ...
526523

527524

528525
def vex(s, check=False):

spatialmath/base/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
"SO2Array",
6262
"SE2Array",
6363
"SE3Array",
64-
6564
"R8x8",
6665
"R3x3",
6766
"RNx3",

spatialmath/base/vectors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> float | NDArray:
567567

568568
y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, n * np.pi - theta)
569569
if isinstance(y, np.ndarray) and len(y) == 1:
570-
return float(y)
570+
return float(y.squeeze())
571571
else:
572572
return y
573573

@@ -585,7 +585,7 @@ def wrap_0_2pi(theta: ArrayLike) -> float | NDArray:
585585
theta = getvector(theta)
586586
y = theta - 2.0 * math.pi * np.floor(theta / 2.0 / np.pi)
587587
if isinstance(y, np.ndarray) and len(y) == 1:
588-
return float(y)
588+
return float(y.squeeze())
589589
else:
590590
return y
591591

@@ -603,7 +603,7 @@ def wrap_mpi_pi(theta: ArrayLike) -> float | NDArray:
603603
theta = getvector(theta)
604604
y = np.mod(theta + math.pi, 2 * math.pi) - np.pi
605605
if isinstance(y, np.ndarray) and len(y) == 1:
606-
return float(y)
606+
return float(y.squeeze())
607607
else:
608608
return y
609609

@@ -658,7 +658,7 @@ def angdiff(a, b=None):
658658

659659
y = np.mod(a + math.pi, 2 * math.pi) - math.pi
660660
if isinstance(y, np.ndarray) and len(y) == 1:
661-
return float(y)
661+
return float(y.squeeze())
662662
else:
663663
return y
664664

spatialmath/geom2d.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
R3,
3030
R4,
3131
)
32-
from typing import Self
32+
from typing import Self, Iterator, cast
3333

3434
_eps = np.finfo(np.float64).eps
3535

@@ -93,7 +93,7 @@ def General(cls, m, c) -> Self:
9393
"""
9494
return cls([m, -1, c])
9595

96-
def general(self) -> Tuple[float, float]:
96+
def general(self) -> tuple[float, float]:
9797
r"""
9898
Parameters of general line
9999
@@ -480,7 +480,7 @@ def animate(self, T, **kwargs) -> None:
480480
self.patch = PathPatch(self.path, **self.kwargs)
481481
self.ax.add_patch(self.patch)
482482

483-
def contains(self, p: ArrayLike2, radius: float = 0.0) -> Union[bool, List[bool]]:
483+
def contains(self, p: ArrayLike2, radius: float = 0.0) -> bool | list[bool]:
484484
"""
485485
Test if point is inside polygon
486486
@@ -564,7 +564,7 @@ def radius(self) -> float:
564564
return dmax
565565

566566
def intersects(
567-
self, other: Union[Polygon2, Line2, List[Polygon2], List[Line2]]
567+
self, other: Polygon2 | Line2 | list[Polygon2] | list[Line2]
568568
) -> bool:
569569
"""
570570
Test for intersection
@@ -590,12 +590,12 @@ def intersects(
590590
return True
591591
return False
592592
elif smb.islistof(other, Polygon2):
593-
for polygon in cast(List[Polygon2], other):
593+
for polygon in cast(list[Polygon2], other):
594594
if self.path.intersects_path(polygon.path, filled=True):
595595
return True
596596
return False
597597
elif smb.islistof(other, Line2):
598-
for line in cast(List[Line2], other):
598+
for line in cast(list[Line2], other):
599599
for p1, p2 in self.edges():
600600
# test each edge segment against the line
601601
if line.intersect_segment(p1, p2):
@@ -960,7 +960,7 @@ def theta(self) -> float:
960960
"""
961961
e, x = np.linalg.eigh(self.E)
962962
# major axis is second column
963-
return np.arctan(x[1, 1] / x[0, 1])
963+
return np.atan2(x[1, 1], x[0, 1]) % (np.pi)
964964

965965
@property
966966
def area(self) -> float:

spatialmath/quaternion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ class Quaternion(BasePoseList):
5151
:parts: 1
5252
"""
5353

54-
def __init__(self, s: Any = None, v=None) -> None:
54+
def __init__(self, s: Any = None, v=None, check: bool = True) -> None:
5555
r"""
5656
Construct a new quaternion
5757
5858
:param s: scalar
5959
:type s: float
6060
:param v: vector
6161
:type v: 3-element array_like
62-
62+
:param check: does nothing, for compatibility with other classes
6363
- ``Quaternion()`` constructs a zero quaternion
6464
- ``Quaternion(s, v)`` construct a new quaternion from the scalar ``s``
6565
and the vector ``v``

spatialmath/spline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def visualize(
6060
) # plot compare to input poses
6161

6262
if animate:
63+
print("Before animate")
6364
tranimate(
6465
samples, length=pose_marker_length, wait=True, repeat=repeat
6566
) # animate pose along trajectory

tests/base/test_symbolic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,37 +18,37 @@
1818

1919
class Test_symbolic(unittest.TestCase):
2020
def test_symbol(self):
21-
theta = sp.symbols("theta")
21+
theta = sp.symbols("theta", real=True)
2222
self.assertTrue(isinstance(theta, sp.Expr))
2323
self.assertTrue(theta.is_real)
2424

2525
theta = sp.symbols("theta", real=False)
2626
self.assertTrue(isinstance(theta, sp.Expr))
2727
self.assertFalse(theta.is_real)
2828

29-
theta, psi = sp.symbols("theta, psi")
29+
theta, psi = sp.symbols("theta, psi", real=True)
3030
self.assertTrue(isinstance(theta, sp.Expr))
3131
self.assertTrue(isinstance(psi, sp.Expr))
3232

33-
theta, psi = sp.symbols("theta psi")
33+
theta, psi = sp.symbols("theta psi", real=True)
3434
self.assertTrue(isinstance(theta, sp.Expr))
3535
self.assertTrue(isinstance(psi, sp.Expr))
3636

37-
q = sp.symbols("q:6")
37+
q = sp.symbols("q:6", real=True)
3838
self.assertEqual(len(q), 6)
3939
for _ in q:
4040
self.assertTrue(isinstance(_, sp.Expr))
4141
self.assertTrue(_.is_real)
4242

4343
def test_issymbol(self):
44-
theta = sp.symbols("theta")
44+
theta = sp.symbols("theta", real=True)
4545
self.assertFalse(issymbol(3))
4646
self.assertFalse(issymbol("not a symbol"))
4747
self.assertFalse(issymbol([1, 2]))
4848
self.assertTrue(issymbol(theta))
4949

5050
def test_functions(self):
51-
theta = sp.symbols("theta")
51+
theta = sp.symbols("theta", real=True)
5252
self.assertTrue(isinstance(sin(theta), sp.Expr))
5353
self.assertTrue(isinstance(sin(1.0), float))
5454

0 commit comments

Comments
 (0)
0