8000 make norm() handle symbolic arguments · StephLin/spatialmath-python@ed38605 · GitHub
[go: up one dir, main page]

Skip to content

Commit ed38605

Browse files
committed
make norm() handle symbolic arguments
1 parent 3399e7e commit ed38605

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

spatialmath/base/test/test_vectors.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from scipy.linalg import logm, expm
1616

1717
from spatialmath.base.vectors import *
18-
18+
from spatialmath.base import sym
1919
import matplotlib.pyplot as plt
2020

2121

@@ -74,8 +74,13 @@ def test_isunitvec(self):
7474
self.assertFalse(isunitvec([-2]))
7575

7676
def test_norm(self):
77-
nt.assert_array_almost_equal(norm([0, 0, 0]), 0)
78-
nt.assert_array_almost_equal(norm([1, 2, 3]), math.sqrt(14))
77+
self.assertAlmostEqual(norm([0, 0, 0]), 0)
78+
self.assertAlmostEqual(norm([1, 2, 3]), math.sqrt(14))
79+
80+
x, y = sym.symbol('x y')
81+
v = [x, y]
82+
self.assertEqual(norm(v), sym.sqrt(x**2 + y**2))
83+
self.assertEqual(norm(np.r_[v]), sym.sqrt(x**2 + y**2))
7984

8085
def test_isunittwist(self):
8186
# 3D

spatialmath/base/vectors.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@
1515
import numpy as np
1616
from spatialmath.base import getvector
1717

18+
try: # pragma: no cover
19+
# print('Using SymPy')
20+
from sympy import Matrix
21+
22+
_symbolics = True
23+
24+
except ImportError: # pragma: no cover
25+
_symbolics = False
26+
1827
_eps = np.finfo(np.float64).eps
1928

2029

@@ -60,7 +69,7 @@ def unitvec(v):
6069
"""
6170

6271
v = getvector(v)
63-
n = np.linalg.norm(v)
72+
n = norm(v)
6473

6574
if n > 100 * _eps: # if greater than eps
6675
return v / n
@@ -115,8 +124,13 @@ def norm(v):
115124
116125
:seealso: :func:`~spatialmath.base.unit`
117126
127+
:SymPy: supported
118128
"""
119-
return np.linalg.norm(v)
129+
v = getvector(v)
130+
if v.dtype.kind == 'O':
131+
return Matrix(v).norm()
132+
else:
133+
return np.linalg.norm(v)
120134

121135

122136
def isunitvec(v, tol=10):

0 commit comments

Comments
 (0)
0