8000 Issues #101: SpatialInertia.__add__ method fix and test (#109) · petercorke/spatialmath-python@f3e28f8 · GitHub
[go: up one dir, main page]

Skip to content

Commit f3e28f8

Browse files
authored
Issues bdaiinstitute#101: SpatialInertia.__add__ method fix and test (bdaiinstitute#109)
1 parent a0bb12f commit f3e28f8

File tree

2 files changed

+36
-26
lines changed

2 files changed

+36
-26
lines changed

spatialmath/spatialvector.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ def __init__(self, m=None, r=None, I=None):
543543
:param I: inertia about the centre of mass, axes aligned with link frame
544544
:type I: numpy.array, shape=(6,6)
545545
546-
- ``SpatialInertia(m, r I)`` is a spatial inertia object for a rigid-body
546+
- ``SpatialInertia(m, r, I)`` is a spatial inertia object for a rigid-body
547547
with mass ``m``, centre of mass at ``r`` relative to the link frame, and an
548548
inertia matrix ``I`` (3x3) about the centre of mass.
549549
@@ -588,8 +588,9 @@ def isvalid(self, x, check):
588588
:return: True if the matrix has shape (6,6).
589589
:rtype: bool
590590
"""
591-
return self.shape == SpatialVector.shape
591+
return self.shape == x.shape
592592

593+
@property
593594
def shape(self):
594595
"""
595596
Shape of the object's interal matrix representation
@@ -603,7 +604,6 @@ def __getitem__(self, i):
603604
return SpatialInertia(self.data[i])
604605

605606
def __repr__(self):
606-
607607
"""
608608
Convert to string
609609
@@ -634,7 +634,7 @@ def __add__(
634634
"""
635635
if not isinstance(right, SpatialInertia):
636636
raise TypeError("can only add spatial inertia to spatial inertia")
637-
return SpatialInertia(left.I + left.I)
637+
return SpatialInertia(left.A + right.A)
638638

639639
def __mul__(
640640
left, right
@@ -682,7 +682,6 @@ def __rmul__(
682682

683683

684684
if __name__ == "__main__":
685-
686685
import numpy.testing as nt
687686
import pathlib
688687

tests/test_spatialvector.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import unittest
32
import numpy.testing as nt
43
import numpy as np
@@ -55,8 +54,8 @@ def test_velocity(self):
5554

5655
s = str(a)
5756
self.assertIsInstance(s, str)
58-
self.assertEqual(s.count('\n'), 0)
59-
self.assertTrue(s.startswith('SpatialVelocity'))
57+
self.assertEqual(s.count("\n"), 0)
58+
self.assertTrue(s.startswith("SpatialVelocity"))
6059

6160
r = np.random.rand(6, 10)
6261
a = SpatialVelocity(r)
@@ -70,11 +69,11 @@ def test_velocity(self):
7069
self.assertIsInstance(b, SpatialVector)
7170
self.assertIsInstance(b, SpatialM6)
7271
self.assertEqual(len(b), 1)
73-
self.assertTrue(all(b.A == r[:,3]))
72+
self.assertTrue(all(b.A == r[:, 3]))
7473

7574
s = str(a)
7675
self.assertIsInstance(s, str)
77-
self.assertEqual(s.count('\n'), 9)
76+
self.assertEqual(s.count("\n"), 9)
7877

7978
def test_acceleration(self):
8079
a = SpatialAcceleration([1, 2, 3, 4, 5, 6])
@@ -93,8 +92,8 @@ def test_acceleration(self):
9392

9493
s = str(a)
9594
self.assertIsInstance(s, str)
96-
self.assertEqual(s.count('\n'), 0)
97-
self.assertTrue(s.startswith('SpatialAcceleration'))
95+
self.assertEqual(s.count("\n"), 0)
96+
self.assertTrue(s.startswith("SpatialAcceleration"))
9897

9998
r = np.random.rand(6, 10)
10099
a = SpatialAcceleration(r)
@@ -108,14 +107,12 @@ def test_acceleration(self):
108107
self.assertIsInstance(b, SpatialVector)
109108
self.assertIsInstance(b, SpatialM6)
110109
self.assertEqual(len(b), 1)
111-
self.assertTrue(all(b.A == r[:,3]))
110+
self.assertTrue(all(b.A == r[:, 3]))
112111

113112
s = str(a)
114113
self.assertIsInstance(s, str)
115114

116-
117115
def test_force(self):
118-
119116
a = SpatialForce([1, 2, 3, 4, 5, 6])
120117
self.assertIsInstance(a, SpatialForce)
121118
self.assertIsInstance(a, SpatialVector)
@@ -132,8 +129,8 @@ def test_force(self):
132129

133130
s = str(a)
134131
self.assertIsInstance(s, str)
135-
self.assertEqual(s.count('\n'), 0)
136-
self.assertTrue(s.startswith('SpatialForce'))
132+
self.assertEqual(s.count("\n"), 0)
133+
self.assertTrue(s.startswith("SpatialForce"))
137134

138135
r = np.random.rand(6, 10)
139136
a = SpatialForce(r)
@@ -153,7 +150,6 @@ def test_force(self):
153150
self.assertIsInstance(s, str)
154151

155152
def test_momentum(self):
156-
157153
a = SpatialMomentum([1, 2, 3, 4, 5, 6])
158154
self.assertIsInstance(a, SpatialMomentum)
159155
self.assertIsInstance(a, SpatialVector)
@@ -170,8 +166,8 @@ def test_momentum(self):
170166

171167
s = str(a)
172168
self.assertIsInstance(s, str)
173-
self.assertEqual(s.count('\n'), 0)
174-
self.assertTrue(s.startswith('SpatialMomentum'))
169+
self.assertEqual(s.count("\n"), 0)
170+
self.assertTrue(s.startswith("SpatialMomentum"))
175171

176172
r = np.random.rand(6, 10)
177173
a = SpatialMomentum(r)
@@ -190,9 +186,7 @@ def test_momentum(self):
190186
s = str(a)
191187
self.assertIsInstance(s, str)
192188

193-
194189
def test_arith(self):
195-
196190
# just test SpatialVelocity since all types derive from same superclass
197191

198192
r1 = np.r_[1, 2, 3, 4, 5, 6]
@@ -206,8 +200,26 @@ def test_arith(self):
206200

207201
def test_inertia(self):
208202
# constructor
203+
i0 = SpatialInertia()
204+
nt.assert_equal(i0.A, np.zeros((6, 6)))
205+
206+
i1 = SpatialInertia(np.eye(6, 6))
207+
nt.assert_equal(i1.A, np.eye(6, 6))
208+
209+
i2 = SpatialInertia(m=1, r=(1, 2, 3))
210+
nt.assert_almost_equal(i2.A, i2.A.T)
211+
212+
i3 = SpatialInertia(m=1, r=(1, 2, 3), I=np.ones((3, 3)))
213+
nt.assert_almost_equal(i3.A, i3.A.T)
214+
209215
# addition
210-
pass
216+
m_a, m_b = 1.1, 2.2
217+
r = (1, 2, 3)
218+
i4a, i4b = SpatialInertia(m=m_a, r=r), SpatialInertia(m=m_b, r=r)
219+
nt.assert_almost_equal((i4a + i4b).A, SpatialInertia(m=m_a + m_b, r=r).A)
220+
221+
# isvalid - note this method is very barebone, to be improved
222+
self.assertTrue(SpatialInertia().isvalid(np.ones((6, 6)), check=False))
211223

212224
def test_products(self):
213225
# v x v = a *, v x F6 = a
@@ -218,6 +230,5 @@ def test_products(self):
218230

219231

220232
# ---------------------------------------------------------------------------------------#
221-
if __name__ == '__main__':
222-
223-
unittest.main()
233+
if __name__ == "__main__":
234+
unittest.main()

0 commit comments

Comments
 (0)
0