8000 plotting, features + doco · krenshaw2018/spatialmath-python@c13f1c7 · GitHub
[go: up one dir, main page]

Skip to content

Commit c13f1c7

Browse files
committed
plotting, features + doco
1 parent d3feb68 commit c13f1c7

File tree

2 files changed

+69
-47
lines changed

2 files changed

+69
-47
lines changed

spatialmath/base/graphics.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import matplotlib.pyplot as plt
22
from numpy.core.defchararray import center
3-
from spatialmath.base.vectors import getvector
3+
from spatialmath import base
44
import numpy as np
55
import scipy as sp
66

@@ -161,9 +161,13 @@ def plot_point(pos, marker='bs', text=None, ax=None, color=None, textargs=None,
161161
162162
"""
163163

164-
if isinstance(pos, np.ndarray) and pos.shape[0] == 2:
165-
x = pos[0,:]
166-
y = pos[1,:]
164+
if isinstance(pos, np.ndarray):
165+
if pos.ndim == 1:
166+
x = pos[0]
167+
y = pos[1]
168+
elif pos.ndim == 2 and pos.shape[0] == 2:
169+
x = pos[0,:]
170+
y = pos[1,:]
167171
elif isinstance(pos, (tuple, list)):
168172
# [x, y]
169173
# [(x,y), (x,y), ...]
@@ -199,7 +203,7 @@ def plot_point(pos, marker='bs', text=None, ax=None, color=None, textargs=None,
199203
for i, xy in enumerate(zip(x, y)):
200204
plt.text(xy[0], xy[1], ' ' + text.format(i), color=color, **textopts)
201205
except:
202-
plt.text(x, y, ' ' + text, horizontalalignment='left', verticalalignment='center', color=color, **textopts)
206+
plt.text(x, y, ' ' + text, ha='left', va='center', color=color, **textopts)
203207

204208

205209

@@ -408,7 +412,7 @@ def isnotebook():
408412
except NameError:
409413
return False # Probably standard Python interpreter
410414

411-
def plotvol2(dim, ax=None, equal=False):
415+
def plotvol2(dim, ax=None, equal=True, grid=False):
412416
"""
413417
Create 2D plot area
414418
@@ -436,9 +440,11 @@ def plotvol2(dim, ax=None, equal=False):
436440

437441
if equal:
438442
ax.set_aspect('equal')
443+
if grid:
444+
ax.grid(True)
439445
return ax
440446

441-
def plotvol3(dim, ax=None, equal=False, projection='ortho'):
447+
def plotvol3(dim, ax=None, equal=True, grid=Fal 10000 se, projection='ortho'):
442448
"""
443449
Create 3D plot volume
444450
@@ -468,7 +474,9 @@ def plotvol3(dim, ax=None, equal=False, projection='ortho'):
468474
ax.set_zlabel('Z')
469475

470476
if equal:
471-
ax.set_aspect('equal')
477+
ax.set_box_aspect((1,) * 3)
478+
if grid:
479+
ax.grid(True)
472480
return ax
473481

474482

@@ -496,7 +504,7 @@ def expand_dims(dim=None, nd=2):
496504
* [A,B] -> [A, B, A, B, A, B]
497505
* [A,B,C,D,E,F] -> [A, B, C, D, E, F]
498506
"""
499-
dim = getvector(dim)
507+
dim = base.getvector(dim)
500508

501509
if nd == 2:
502510
if len(dim) == 1:

spatialmath/base/transforms3d.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from math import sin, cos
1919
import numpy as np
2020
from spatialmath import base
21+
from collections.abc import Iterable
2122

2223
_eps = np.finfo(np.float64).eps
2324

@@ -1836,7 +1837,7 @@ def _vec2s(fmt, v):
18361837

18371838
def trplot(T, axes=None, block=False, dims=None, color='blue', frame=None, # pylint: disable=unused-argument,function-redefined
18381839
textcolor=None, labels=('X', 'Y', 'Z'), length=1, style='arrow',
1839-
origindot=None, projection='ortho', wtl=0.2, width=None, d1=0.05,
1840+
originsize=20, origincolor=None, projection='ortho', wtl=0.2, width=None, d1=0.05,
18401841
d2=1.15, anaglyph=None, **kwargs):
18411842
"""
18421843
Plot a 3D coordinate frame
@@ -1851,19 +1852,21 @@ def trplot(T, axes=None, block=False, dims=None, color='blue', frame=None, # p
18511852
If dims is [min, max] those limits are applied to the x-, y- and z-axes.
18521853
:type dims: array_like(6) or array_like(2)
18531854
:param color: color of the lines defining the frame
1854-
:type color: str
1855-
:param textcolor: color of text labels for the frame, default color of lines above
1855+
:type color: str or list(3) of str
1856+
:param textcolor: color of text labels for the frame, default ``color``
18561857
:type textcolor: str
18571858
:param frame: label the frame, name is shown below the frame and as subscripts on the frame axis labels
18581859
:type frame: str
18591860
:param labels: labels for the axes, defaults to X, Y and Z
18601861
:type labels: 3-tuple of strings
18611862
:param length: length of coordinate frame axes, default 1
1862-
:type length: float
1863-
:param style: axis style: 'arrow' [default], 'line', 'rgb' (Rviz style)
1863+
:type length: float or array_like(3)
1864+
:param style: axis style: 'arrow' [default], 'line', 'rviz' (Rviz style)
18641865
:type style: str
1865-
:param origindot: size of dot to draw at the origin (default 20)
1866-
:type origindot: int
1866+
:param originsize: size of dot to draw at the origin, 0 for no dot (default 20)
1867+
:type originsize: int
1868+
:param origincolor: color of dot to draw at the origin, default is ``color``
1869+
:type origincolor: str
18671870
:param anaglyph: 3D anaglyph display, left-right lens colors eg. ``'rc'``
18681871
for red-cyan glasses. To set the disparity (default 0.1) provide second
18691872
argument in a tuple, eg. ``('rc', 0.2)``. Bigger disparity exagerates the
@@ -1906,7 +1909,7 @@ def trplot(T, axes=None, block=False, dims=None, color='blue', frame=None, # p
19061909
19071910
.. note:: The origin is normally indicated with a marker of the same color
19081911
as the frame. The default size is 20. This can be disabled by setting
1909-
its size to zero by ``origindot=0``. For ``'rgb'`` style the default is 0
1912+
its size to zero by ``originsize=0``. For ``'rgb'`` style the default is 0
19101913
but it can be set explicitly, and the color is as per the ``color``
19111914
option.
19121915
@@ -1973,18 +1976,19 @@ def trplot(T, axes=None, block=False, dims=None, color='blue', frame=None, # p
19731976

19741977
return
19751978

1976-
if style == 'rgb':
1977-
if origindot is None:
1978-
origindot = 0
1979-
colors = ('red', 'green', 'blue')
1980-
color = 'k'
1981-
width = 8
1979+
if style == 'rviz':
1980+
if originsize is None:
1981+
originsize = 0
1982+
color = 'rgb'
1983+
if width is None:
1984+
width = 8
19821985
style = 'line'
1983-
else:
1984-
colors = (color,) * 3
1985-
width = 1
1986-
if origindot is None:
1987-
origindot = 20
1986+
1987+
if isinstance(color, str):
1988+
if color == 'rgb':
1989+
color = ('red', 'green', 'blue')
1990+
else:
1991+
color = (color,) * 3
19881992

19891993
# check input types
19901994
if isrot(T, check=True):
@@ -1996,7 +2000,7 @@ def trplot(T, axes=None, block=False, dims=None, color='blue', frame=None, # p
19962000
for Tk in T:
19972001
trplot(Tk, axes=ax, block=block, dims=dims, color=color, frame=frame,
19982002
textcolor=textcolor, labels=labels, length=length, style=style,
1999-
projection=projection, wtl=wtl, width=width, d1=d1,
2003+
projection=projection, originsize=originsize, origincolor=origincolor, wtl=wtl, width=width, d1=d1,
20002004
d2=d2, anaglyph=anaglyph, **kwargs)
20012005
return
20022006

@@ -2008,47 +2012,56 @@ def trplot(T, axes=None, block=False, dims=None, color='blue', frame=None, # p
20082012
ax.set_zlim(dims[4:6])
20092013

20102014
# create unit vectors in homogeneous form
2015+
if not isinstance(length, Iterable):
2016+
length = (length,) * 3
2017+
20112018
o = T @ np.array([0, 0, 0, 1])
2012-
x = T @ np.array([length, 0, 0, 1])
2013-
y = T @ np.array([0, length, 0, 1])
2014-
z = T @ np.array([0, 0, length, 1])
2019+
x = T @ np.array([length[0], 0, 0, 1])
2020+
y = T @ np.array([0, length[1], 0, 1])
2021+
z = T @ np.array([0, 0, length[2], 1])
20152022

20162023
# draw the axes
20172024

20182025
if style == 'arrow':
2019-
ax.quiver(o[0], o[1], o[2], x[0] - o[0], x[1] - o[1], x[2] - o[2], arrow_length_ratio=wtl, linewidth=width, facecolor=color, edgecolor=color)
2020-
ax.quiver(o[0], o[1], o[2], y[0] - o[0], y[1] - o[1], y[2] - o[2], arrow_length_ratio=wtl, linewidth=width, facecolor=color, edgecolor=color)
2021-
ax.quiver(o[0], o[1], o[2], z[0] - o[0], z[1] - o[1], z[2] - o[2], arrow_length_ratio=wtl, linewidth=width, facecolor=color, edgecolor=color)
2026+
ax.quiver(o[0], o[1], o[2], x[0] - o[0], x[1] - o[1], x[2] - o[2], arrow_length_ratio=wtl, linewidth=width, facecolor=color[0], edgecolor=color[1])
2027+
ax.quiver(o[0], o[1], o[2], y[0] - o[0], y[1] - o[1], y[2] - o[2], arrow_length_ratio=wtl, linewidth=width, facecolor=color[1], edgecolor=color[1])
2028+
ax.quiver(o[0], o[1], o[2], z[0] - o[0], z[1] - o[1], z[2] - o[2], arrow_length_ratio=wtl, linewidth=width, facecolor=color[2], edgecolor=color[2])
20222029

20232030
# plot some points
20242031
# invisible point at the end of each arrow to allow auto-scaling to work
20252032
ax.scatter(xs=[o[0], x[0], y[0], z[0]], ys=[o[1], x[1], y[1], z[1]], zs=[o[2], x[2], y[2], z[2]],
20262033
s=[0, 0, 0, 0])
20272034
elif style == 'line':
2028-
ax.plot([o[0], x[0]], [o[1], x[1]], [o[2], x[2]], color=colors[0], linewidth=width)
2029-
ax.plot([o[0], y[0]], [o[1], y[1]], [o[2], y[2]], color=colors[1], linewidth=width)
2030-
ax.plot([o[0], z[0]], [o[1], z[1]], [o[2], z[2]], color=colors[2], linewidth=width)
2031-
2032-
if origindot > 0:
2033-
ax.scatter(xs=[o[0]], ys=[o[1]], zs=[o[2]], color=color, s=origindot)
2035+
ax.plot([o[0], x[0]], [o[1], x[1]], [o[2], x[2]], color=color[0], linewidth=width)
2036+
ax.plot([o[0], y[0]], [o[1], y[1]], [o[2], y[2]], color=color[1], linewidth=width)
2037+
ax.plot([o[0], z[0]], [o[1], z[1]], [o[2], z[2]], color=color[2], linewidth=width)
20342038

20352039
# label the frame
20362040
if frame:
2037-
if textcolor is not None:
2038-
color = textcolor
2041+
if textcolor is None:
2042+
textcolor = color[0]
2043+
else:
2044+
textcolor = 'blue'
2045+
if origincolor is None:
2046+
origincolor = color[0]
2047+
else:
2048+
origincolor = 'black'
20392049

20402050
o1 = T @ np.array([-d1, -d1, -d1, 1])
2041-
ax.text(o1[0], o1[1], o1[2], r'$\{' + frame + r'\}$', color=color, verticalalignment='top', horizontalalignment='center')
2051+
ax.text(o1[0], o1[1], o1[2], r'$\{' + frame + r'\}$', color=textcolor, verticalalignment='top', horizontalalignment='center')
20422052

20432053
# add the labels to each axis
20442054

20452055
x = (x - o) * d2 + o
20462056
y = (y - o) * d2 + o
20472057
z = (z - o) * d2 + o
20482058

2049-
ax.text(x[0], x[1], x[2], "$%c_{%s}$" % (labels[0], frame), color=color, horizontalalignment='center', verticalalignment='center')
2050-
ax.text(y[0], y[1], y[2], "$%c_{%s}$" % (labels[1], frame), color=color, horizontalalignment='center', verticalalignment='center')
2051-
ax.text(z[0], z[1], z[2], "$%c_{%s}$" % (labels[2], frame), color=color, horizontalalignment='center', verticalalignment='center')
2059+
ax.text(x[0], x[1], x[2], "$%c_{%s}$" % (labels[0], frame), color=textcolor, horizontalalignment='center', verticalalignment='center')
2060+
ax.text(y[0], y[1], y[2], "$%c_{%s}$" % (labels[1], frame), color=textcolor, horizontalalignment='center', verticalalignment='center')
2061+
ax.text(z[0], z[1], z[2], "$%c_{%s}$" % (labels[2], frame), color=textcolor, horizontalalignment='center', verticalalignment='center')
2062+
2063+
if originsize > 0:
2064+
ax.scatter(xs=[o[0]], ys=[o[1]], zs=[o[2]], color=origincolor, s=originsize)
20522065

20532066
if block:
20542067
# calling this at all, causes FuncAnimation to fail so when invoked from tranimate skip this bit
@@ -2114,6 +2127,7 @@ def tranimate(T, **kwargs):
21142127
plt.show(block=block)
21152128

21162129
if __name__ == '__main__': # pragma: no cover
2130+
21172131
import pathlib
21182132

21192133
exec(open(pathlib.Path(__file__).parent.parent.parent.absolute() / "tests" / "base" / "test_transforms3d.py").read()) # pylint: disable=exec-used

0 commit comments

Comments
 (0)
0