8000 modest change to semantics of getunit() and code changes to suit · RedCarp0/spatialmath-python@9028e46 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9028e46

Browse files
committed
modest change to semantics of getunit() and code changes to suit
1 parent 4ef89e3 commit 9028e46

File tree

7 files changed

+123
-117
lines changed

7 files changed

+123
-117
lines changed

spatialmath/base/argcheck.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import math
1515
import numpy as np
16+
from collections.abc import Iterable
1617

1718
# from spatialmath.base import symbolic as sym # HACK
1819
from spatialmath.base.symbolic import issymbol, symtype
@@ -371,6 +372,8 @@ def getvector(
371372
>>> getvector(1) # scalar
372373
>>> getvector([1])
373374
>>> getvector([[1]])
375+
>>> getvector([1,2], 2)
376+
>>> # getvector([1,2], 3) --> ValueError
374377
375378
.. note::
376379
- For 'array', 'row' or 'col' output the NumPy dtype defaults to the
@@ -519,68 +522,62 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool:
519522
return False
520523

521524

522-
@overload
523-
def getunit(v: float, unit: str = "rad") -> float: # pragma: no cover
524-
...
525-
526-
527-
@overload
528-
def getunit(v: NDArray, unit: str = "rad") -> NDArray: # pragma: no cover
529-
...
530-
531-
532-
@overload
533-
def getunit(v: List[float], unit: str = "rad") -> List[float]: # pragma: no cover
534-
...
535-
536-
537-
@overload
538-
def getunit(v: Tuple[float, ...], unit: str = "rad") -> List[float]: # pragma: no cover
539-
...
540-
541-
542-
def getunit(
543-
v: Union[float, NDArray, Tuple[float, ...], List[float]], unit: str = "rad"
544-
) -> Union[float, NDArray, List[float]]:
525+
def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
545526
"""
546-
Convert value according to angular units
527+
Convert values according to angular units
547528
548529
:param v: the value in radians or degrees
549-
:type v: array_like(m) or ndarray(m)
530+
:type v: array_like(m)
550531
:param unit: the angular unit, "rad" or "deg"
551532
:type unit: str
533+
:param dim: expected dimension of input, defaults to None
534+
:type dim: int, optional
552535
:return: the converted value in radians
553-
:rtype: list(m) or ndarray(m)
536+
:rtype: ndarray(m) or float
554537
:raises ValueError: argument is not a valid angular unit
555538
556-
The input can be a list or ndarray() and the output is the same type.
539+
The input value is assumed to be in units of ``unit`` and is converted to radians.
557540
558541
.. runblock:: pycon
559542
560543
>>> from spatialmath.base import getunit
561544
>>> import numpy as np
562545
>>> getunit(1.5, 'rad')
546+
>>> getunit(1.5, 'rad', dim=0)
547+
>>> # getunit([1.5], 'rad', dim=0) --> ValueError
563548
>>> getunit(90, 'deg')
564549
>>> getunit([90, 180], 'deg')
565550
>>> getunit(np.r_[0.5, 1], 'rad')
566551
>>> getunit(np.r_[90, 180], 'deg')
552+
>>> getunit(np.r_[90, 180], 'deg', dim=2)
553+
>>> # getunit([90, 180], 'deg', dim=3) --> ValueError
554+
555+
:note:
556+
- the input value is processed by :func:`getvector` and the argument ``dim`` can
557+
be used to check that ``v`` is the desired length.
558+
- the output is always an ndarray except if the input is a scalar and ``dim=0``.
559+
560+
:seealso: :func:`getvector`
567561
"""
568-
if unit == "rad":
569-
if isinstance(v, tuple):
570-
return list(v)
571-
else:
562+
if not isinstance(v, Iterable) and dim == 0:
563+
# scalar in, scalar out
564+
if unit == "rad":
572565
return v
573-
elif unit == "deg":
574-
if isinstance(v, np.ndarray) or np.isscalar(v):
575-
return v * math.pi / 180 # type: ignore
576-
elif isinstance(v, (list, tuple)):
577-
return [x * math.pi / 180 for x in v]
566+
elif unit == "deg":
567+
return np.deg2rad(v)
578568
else:
579-
raise ValueError("bad argument")
580-
else:
581-
raise ValueError("invalid angular units")
569+
raise ValueError("invalid angular units")
582570

583-
return ret
571+
else:
572+
# scalar or iterable in, ndarray out
573+
# iterable passed in
574+
v = getvector(v, dim=dim)
575+
if unit == "rad":
576+
return v
577+
elif unit == "deg":
578+
return np.deg2rad(v)
579+
else:
580+
raise ValueError("invalid angular units")
584581

585582

586583
def isnumberlist(x: Any) -> bool:

spatialmath/base/transforms2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array:
6262
>>> rot2(0.3)
6363
>>> rot2(45, 'deg')
6464
"""
65-
theta = smb.getunit(theta, unit)
65+
theta = smb.getunit(theta, unit, dim=0)
6666
ct = smb.sym.cos(theta)
6767
st = smb.sym.sin(theta)
6868
# fmt: off

spatialmath/base/transforms3d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def rotx(theta: float, unit: str = "rad") -> SO3Array:
7979
:SymPy: supported
8080
"""
8181

82-
theta = getunit(theta, unit)
82+
theta = getunit(theta, unit, dim=0)
8383
ct = sym.cos(theta)
8484
st = sym.sin(theta)
8585
# fmt: off
@@ -118,7 +118,7 @@ def roty(theta: float, unit: str = "rad") -> SO3Array:
118118
:SymPy: supported
119119
"""
120120

121-
theta = getunit(theta, unit)
121+
theta = getunit(theta, unit, dim=0)
122122
ct = sym.cos(theta)
123123
st = sym.sin(theta)
124124
# fmt: off
@@ -152,7 +152,7 @@ def rotz(theta: float, unit: str = "rad") -> SO3Array:
152152
:seealso: :func:`~trotz`
153153
:SymPy: supported
154154
"""
155-
theta = getunit(theta, unit)
155+
theta = getunit(theta, unit, dim=0)
156156
ct = sym.cos(theta)
157157
st = sym.sin(theta)
158158
# fmt: off

spatialmath/quaternion.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,12 +1118,12 @@ def vec3(self) -> R3:
11181118

11191119
# -------------------------------------------- constructor variants
11201120
@classmethod
1121-
def Rx(cls, angle: float, unit: Optional[str] = "rad") -> UnitQuaternion:
1121+
def Rx(cls, angles: ArrayLike, unit: Optional[str] = "rad") -> UnitQuaternion:
11221122
"""
11231123
Construct a UnitQuaternion object representing rotation about the X-axis
11241124
11251125
:arg θ: rotation angle
1126-
:type θ: float or array_like
1126+
:type θ: array_like
11271127
:arg unit: rotation unit 'rad' [default] or 'deg'
11281128
:type unit: str
11291129
:return: unit-quaternion
@@ -1142,18 +1142,18 @@ def Rx(cls, angle: float, unit: Optional[str] = "rad") -> UnitQuaternion:
11421142
>>> print(UQ.Rx(0.3))
11431143
>>> print(UQ.Rx([0, 0.3, 0.6]))
11441144
"""
1145-
angles = smb.getunit(smb.getvector(angle), unit)
1145+
angles = smb.getunit(angles, unit)
11461146
return cls(
11471147
[np.r_[math.cos(a / 2), math.sin(a / 2), 0, 0] for a in angles], check=False
11481148
)
11491149

11501150
@classmethod
1151-
def Ry(cls, angle: float, unit: Optional[str] = "rad") -> UnitQuaternion:
1151+
def Ry(cls, angles: ArrayLike, unit: Optional[str] = "rad") -> UnitQuaternion:
11521152
"""
11531153
Construct a UnitQuaternion object representing rotation about the Y-axis
11541154
11551155
:arg θ: rotation angle
1156-
:type θ: float or array_like
1156+
:type θ: array_like
11571157
:arg unit: rotation unit 'rad' [default] or 'deg'
11581158
:type unit: str
11591159
:return: unit-quaternion
@@ -1172,18 +1172,18 @@ def Ry(cls, angle: float, unit: Optional[str] = "rad") -> UnitQuaternion:
11721172
>>> print(UQ.Ry(0.3))
11731173
>>> print(UQ.Ry([0, 0.3, 0.6]))
11741174
"""
1175-
angles = smb.getunit(smb.getvector(angle), unit)
1175+
angles = smb.getunit(angles, unit)
11761176
return cls(
11771177
[np.r_[math.cos(a / 2), 0, math.sin(a / 2), 0] for a in angles], check=False
11781178
)
11791179

11801180
@classmethod
1181-
def Rz(cls, angle: float, unit: Optional[str] = "rad") -> UnitQuaternion:
1181+
def Rz(cls, angles: ArrayLike, unit: Optional[str] = "rad") -> UnitQuaternion:
11821182
"""
11831183
Construct a UnitQuaternion object representing rotation about the Z-axis
11841184
11851185
:arg θ: rotation angle
1186-
:type θ: float or array_like
1186+
:type θ: array_like
11871187
:arg unit: rotation unit 'rad' [default] or 'deg'
11881188
:type unit: str
11891189
:return: unit-quaternion
@@ -1202,7 +1202,7 @@ def Rz(cls, angle: float, unit: Optional[str] = "rad") -> UnitQuaternion:
12021202
>>> print(UQ.Rz(0.3))
12031203
>>> print(UQ.Rz([0, 0.3, 0.6]))
12041204
"""
1205-
angles = smb.getunit(smb.getvector(angle), unit)
1205+
angles = smb.getunit(angles, unit)
12061206
return cls(
12071207
[np.r_[math.cos(a / 2), 0, 0, math.sin(a / 2)] for a in angles], check=False
12081208
)
< F438 /div>
@@ -1390,8 +1390,7 @@ def AngVec(
13901390
:seealso: :meth:`UnitQuaternion.angvec` :meth:`UnitQuaternion.exp` :func:`~spatialmath.base.transforms3d.angvec2r`
13911391
"""
13921392
v = smb.getvector(v, 3)
1393-
smb.isscalar(theta)
1394-
theta = smb.getunit(theta, unit)
1393+
theta = smb.getunit(theta, unit, dim=0)
13951394
return cls(
13961395
s=math.cos(theta / 2), v=math.sin(theta / 2) * v, norm=False, check=False
13971396
)

spatialmath/twist.py

Lines changed: 12 additions & 13 deletions
< 10000 tr class="diff-line-row">
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ def SE3(self, theta=1, unit="rad"):
10081008

10091009
theta = smb.getunit(theta, unit)
10101010

1011-
if smb.isscalar(theta):
1011+
if len(theta) == 1:
10121012
# theta is a scalar
10131013
return SE3(smb.trexp(self.S * theta))
10141014
else:
@@ -1064,7 +1064,7 @@ def exp(self, theta=1, unit="rad"):
10641064
"""
10651065
from spatialmath.pose3d import SE3
10661066

1067-
theta = np.r_[smb.getunit(theta, unit)]
1067+
theta = smb.getunit(theta, unit)
10681068

10691069
if len(self) == 1:
10701070
return SE3([smb.trexp(self.S * t) for t in theta], check=False)
@@ -1524,12 +1524,9 @@ def SE2(self, theta=1, unit="rad"):
15241524
if unit != "rad" and self.isprismatic:
15251525
print("Twist3.exp: using degree mode for a prismatic twist")
15261526

1527-
if theta is None:
1528-
theta = 1
1529-
else:
1530-
theta = smb.getunit(theta, unit)
1527+
theta = smb.getunit(theta, unit)
15311528

1532-
if smb.isscalar(theta):
1529+
if len(theta) == 1:
15331530
return SE2(smb.trexp2(self.S * theta))
15341531
else:
15351532
return SE2([smb.trexp2(self.S * t) for t in theta])
@@ -1560,7 +1557,7 @@ def skewa(self):
15601557
else:
15611558
return [smb.skewa(x.S) for x in self]
15621559

1563-
def exp(self, theta=None, unit="rad"):
1560+
def exp(self, theta=1, unit="rad"):
15641561
r"""
15651562
Exponentiate a 2D twist
15661563
@@ -1595,12 +1592,14 @@ def exp(self, theta=None, unit="rad"):
15951592
"""
15961593
from spatialmath.pose2d import SE2
15971594

1598-
if theta is None:
1599-
theta = 1.0
1600-
else:
1601-
theta = smb.getunit(theta, unit)
1595+
theta = smb.getunit(theta, unit)
16021596

1603-
return SE2(smb.trexp2(self.S * theta))
1597+
if len(self) == 1:
1598+
return SE2([smb.trexp2(self.S * t) for t in theta], check=False)
1599+
elif len(self) == len(theta):
1600+
return SE2([smb.trexp2(s * t) for s, t in zip(self.S, theta)], check=False)
1601+
else:
1602+
raise ValueError("length mismatch")
16041603

16051604
def unit(self):
16061605
"""

tests/base/test_argcheck.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def test_ismatrix(self):
2828
self.assertFalse(ismatrix(1, (-1, -1)))
2929

3030
def test_assertmatrix(self):
31-
3231
with self.assertRaises(TypeError):
3332
assertmatrix(3)
3433
with self.assertRaises(TypeError):
@@ -53,7 +52,6 @@ def test_assertmatrix(self):
5352
assertmatrix(a, (None, 4))
5453

5554
def test_getmatrix(self):
56-
5755
a = np.random.rand(4, 3)
5856
self.assertEqual(getmatrix(a, (4, 3)).shape, (4, 3))
5957
self.assertEqual(getmatrix(a, (None, 3)).shape, (4, 3))
@@ -124,19 +122,26 @@ def test_verifymatrix(self):
124122
verifymatrix(a, (3, 4))
125123

126124
def test_unit(self):
125+
self.assertIsInstance(getunit(1), np.ndarray)
126+
self.assertIsInstance(getunit([1, 2]), np.ndarray)
127+
self.assertIsInstance(getunit((1, 2)), np.ndarray)
128+
self.assertIsInstance(getunit(np.r_[1, 2]), np.ndarray)
129+
self.assertIsInstance(getunit(1.0, dim=0), float)
130+
127131
nt.assert_equal(getunit(5, "rad"), 5)
128132
nt.assert_equal(getunit(5, "deg"), 5 * math.pi / 180.0)
129133
nt.assert_equal(getunit([3, 4, 5], "rad"), [3, 4, 5])
130-
nt.assert_equal(
134+
nt.assert_almost_equal(
131135
getunit([3, 4, 5], "deg"), [x * math.pi / 180.0 for x in [3, 4, 5]]
132136
)
133137
nt.assert_equal(getunit((3, 4, 5), "rad"), [3, 4, 5])
134-
nt.assert_equal(
135-
getunit((3, 4, 5), "deg"), [x * math.pi / 180.0 for x in [3, 4, 5]]
138+
nt.assert_almost_equal(
139+
getunit((3, 4, 5), "deg"),
140+
np.array([x * math.pi / 180.0 for x in [3, 4, 5]]),
136141
)
137142

138143
nt.assert_equal(getunit(np.array([3, 4, 5]), "rad"), [3, 4, 5])
139-
nt.assert_equal(
144+
nt.assert_almost_equal(
140145
getunit(np.array([3, 4, 5]), "deg"),
141146
[x * math.pi / 180.0 for x in [3, 4, 5]],
142147
)
@@ -439,7 +444,6 @@ def test_isvectorlist(self):
439444
self.assertFalse(isvectorlist(a, 2))
440445

441446
def test_islistof(self):
442-
443447
a = [3, 4, 5]
444448
self.assertTrue(islistof(a, int))
445449
self.assertFalse(islistof(a, float))
@@ -457,5 +461,4 @@ def test_islistof(self):
457461

458462
# ---------------------------------------------------------------------------------------#
459463
if __name__ == "__main__": # pragma: no cover
460-
461464
unittest.main()

0 commit comments

Comments
 (0)
0