diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a3b93ca..7c28f075 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,10 @@ repos: -# - repo: https://github.com/charliermarsh/ruff-pre-commit -# # Ruff version. -# rev: 'v0.1.0' -# hooks: -# - id: ruff -# args: ['--fix', '--config', 'pyproject.toml'] +- repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: 'v0.1.0' + hooks: + - id: ruff + args: ['--fix', '--config', 'pyproject.toml'] - repo: https://github.com/psf/black rev: 23.10.0 @@ -14,6 +14,21 @@ repos: args: ['--config', 'pyproject.toml'] verbose: true +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: end-of-file-fixer + - id: debug-statements # Ensure we don't commit `import pdb; pdb.set_trace()` + exclude: | + (?x)^( + docker/ros/web/static/.*| + )$ + - id: trailing-whitespace + exclude: | + (?x)^( + docker/ros/web/static/.*| + (.*/).*\.patch| + )$ # - repo: https://github.com/pre-commit/mirrors-mypy # rev: v1.6.1 # hooks: diff --git a/pyproject.toml b/pyproject.toml index 23168906..cead2ac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spatialmath-python" -version = "1.1.12" +version = "1.1.13" authors = [ { name="Peter Corke", email="rvc@petercorke.com" }, ] diff --git a/spatialmath/__init__.py b/spatialmath/__init__.py index e6ef1f77..18cb74b4 100644 --- a/spatialmath/__init__.py +++ b/spatialmath/__init__.py @@ -15,7 +15,7 @@ ) from spatialmath.quaternion import Quaternion, UnitQuaternion from spatialmath.DualQuaternion import DualQuaternion, UnitDualQuaternion -from spatialmath.spline import BSplineSE3 +from spatialmath.spline import BSplineSE3, InterpSplineSE3, SplineFit # from spatialmath.Plucker import * # from spatialmath import base as smb @@ -45,6 +45,8 @@ "Polygon2", "Ellipse", "BSplineSE3", + "InterpSplineSE3", + "SplineFit" ] try: diff --git a/spatialmath/base/animate.py b/spatialmath/base/animate.py index 7654a5a0..a2e31f72 100755 --- a/spatialmath/base/animate.py +++ b/spatialmath/base/animate.py @@ -217,7 +217,7 @@ def update(frame, animation): if isinstance(frame, float): # passed a single transform, interpolate it T = smb.trinterp(start=self.start, end=self.end, s=frame) - elif isinstance(frame, NDArray): + elif isinstance(frame, np.ndarray): # type is SO3Array or SE3Array when Animate.trajectory is not None T = frame else: diff --git a/spatialmath/spline.py b/spatialmath/spline.py index f81dcec3..0a472ecc 100644 --- a/spatialmath/spline.py +++ b/spatialmath/spline.py @@ -2,20 +2,242 @@ # MIT Licence, see details in top-level file: LICENCE """ -Classes for parameterizing a trajectory in SE3 with B-splines. - -Copies parts of the API from scipy's B-spline class. +Classes for parameterizing a trajectory in SE3 with splines. """ -from typing import Any, Dict, List, Optional -from scipy.interpolate import BSpline -from spatialmath import SE3 -import numpy as np +from abc import ABC, abstractmethod +from functools import cached_property +from typing import List, Optional, Tuple, Set + import matplotlib.pyplot as plt -from spatialmath.base.transforms3d import tranimate, trplot +import numpy as np +from scipy.interpolate import BSpline, CubicSpline +from scipy.spatial.transform import Rotation, RotationSpline + +from spatialmath import SE3, SO3, Twist3 +from spatialmath.base.transforms3d import tranimate + + +class SplineSE3(ABC): + def __init__(self) -> None: + self.control_poses: SE3 + + @abstractmethod + def __call__(self, t: float) -> SE3: + pass + + def visualize( + self, + sample_times: List[float], + input_trajectory: Optional[List[SE3]] = None, + pose_marker_length: float = 0.2, + animate: bool = False, + repeat: bool = True, + ax: Optional[plt.Axes] = None, + ) -> None: + """Displays an animation of the trajectory with the control poses against an optional input trajectory. + + Args: + sample_times: which times to sample the spline at and plot + """ + if ax is None: + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(projection="3d") + + samples = [self(t) for t in sample_times] + if not animate: + pos = np.array([pose.t for pose in samples]) + ax.plot( + pos[:, 0], pos[:, 1], pos[:, 2], "c", linewidth=1.0 + ) # plot spline fit + + pos = np.array([pose.t for pose in self.control_poses]) + ax.plot(pos[:, 0], pos[:, 1], pos[:, 2], "r*") # plot control_poses + + if input_trajectory is not None: + pos = np.array([pose.t for pose in input_trajectory]) + ax.plot( + pos[:, 0], pos[:, 1], pos[:, 2], "go", fillstyle="none" + ) # plot compare to input poses + + if animate: + tranimate( + samples, length=pose_marker_length, wait=True, repeat=repeat + ) # animate pose along trajectory + else: + plt.show() + + +class InterpSplineSE3(SplineSE3): + """Class for an interpolated trajectory in SE3, as a function of time, through control_poses with a cubic spline. + + A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic) + under the hood. + """ + + _e = 1e-12 + + def __init__( + self, + timepoints: List[float], + control_poses: List[SE3], + *, + normalize_time: bool = False, + bc_type: str = "not-a-knot", # not-a-knot is scipy default; None is invalid + ) -> None: + """Construct a InterpSplineSE3 object + + Extends the scipy CubicSpline object + https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline + + Args : + timepoints : list of times corresponding to provided poses + control_poses : list of SE3 objects that govern the shape of the spline. + normalize_time : flag to map times into the range [0, 1] + bc_type : boundary condition provided to scipy CubicSpline backend. + string options: ["not-a-knot" (default), "clamped", "natural", "periodic"]. + For tuple options and details see the scipy docs link above. + """ + super().__init__() + self.control_poses = control_poses + self.timepoints = np.array(timepoints) + + if self.timepoints[-1] < self._e: + raise ValueError( + "Difference between start and end timepoints is less than {self._e}" + ) + + if len(self.control_poses) != len(self.timepoints): + raise ValueError("Length of control_poses and timepoints must be equal.") + + if len(self.timepoints) < 2: + raise ValueError("Need at least 2 data points to make a trajectory.") + + if normalize_time: + self.timepoints = self.timepoints - self.timepoints[0] + self.timepoints = self.timepoints / self.timepoints[-1] + + self.spline_xyz = CubicSpline( + self.timepoints, + np.array([pose.t for pose in self.control_poses]), + bc_type=bc_type, + ) + self.spline_so3 = RotationSpline( + self.timepoints, + Rotation.from_matrix(np.array([(pose.R) for pose in self.control_poses])), + ) + + def __call__(self, t: float) -> SE3: + """Compute function value at t. + Return: + pose: SE3 + """ + return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix()) + + def derivative(self, t: float) -> Twist3: + linear_vel = self.spline_xyz.derivative()(t) + angular_vel = self.spline_so3( + t, 1 + ) # 1 is angular rate, 2 is angular acceleration + return Twist3(linear_vel, angular_vel) + + +class SplineFit: + """A general class to fit various SE3 splines to data.""" + + def __init__( + self, + time_data: List[float], + pose_data: List[SE3], + ) -> None: + self.time_data = time_data + self.pose_data = pose_data + self.spline: Optional[SplineSE3] = None + + def stochastic_downsample_interpolation( + self, + epsilon_xyz: float = 1e-3, + epsilon_angle: float = 1e-1, + normalize_time: bool = True, + bc_type: str = "not-a-knot", + check_type: str = "local" + ) -> Tuple[InterpSplineSE3, List[int]]: + """ + Uses a random dropout to downsample a trajectory with an interpolated spline. Keeps the start and + end points of the trajectory. Takes a random order of the remaining indices, and then checks the error bound + of just that point if check_type=="local", checks the error of the whole trajectory is check_type=="global". + Local is **much** faster. + + Return: + downsampled interpolating spline, + list of removed indices from input data + """ + + interpolation_indices = list(range(len(self.pose_data))) + + # randomly attempt to remove poses from the trajectory + # always keep the start and end + removal_choices = interpolation_indices.copy() + removal_choices.remove(0) + removal_choices.remove(len(self.pose_data) - 1) + np.random.shuffle(removal_choices) + for candidate_removal_index in removal_choices: + interpolation_indices.remove(candidate_removal_index) + + self.spline = InterpSplineSE3( + [self.time_data[i] for i in interpolation_indices], + [self.pose_data[i] for i in interpolation_indices], + normalize_time=normalize_time, + bc_type=bc_type, + ) + + sample_time = self.time_data[candidate_removal_index] + if check_type is "local": + angular_error = SO3(self.pose_data[candidate_removal_index]).angdist( + SO3(self.spline.spline_so3(sample_time).as_matrix()) + ) + euclidean_error = np.linalg.norm( + self.pose_data[candidate_removal_index].t - self.spline.spline_xyz(sample_time) + ) + elif check_type is "global": + angular_error = self.max_angular_error() + euclidean_error = self.max_euclidean_error() + else: + raise ValueError(f"check_type must be 'local' of 'global', is {check_type}.") + + if (angular_error > epsilon_angle) or (euclidean_error > epsilon_xyz): + interpolation_indices.append(candidate_removal_index) + interpolation_indices.sort() + self.spline = InterpSplineSE3( + [self.time_data[i] for i in interpolation_indices], + [self.pose_data[i] for i in interpolation_indices], + normalize_time=normalize_time, + bc_type=bc_type, + ) + + return self.spline, interpolation_indices + + def max_angular_error(self) -> float: + return np.max(self.angular_errors()) + + def angular_errors(self) -> List[float]: + return [ + pose.angdist(self.spline(t)) + for pose, t in zip(self.pose_data, self.time_data) + ] + + def max_euclidean_error(self) -> float: + return np.max(self.euclidean_errors()) -class BSplineSE3: + def euclidean_errors(self) -> List[float]: + return [ + np.linalg.norm(pose.t - self.spline(t).t) + for pose, t in zip(self.pose_data, self.time_data) + ] + + +class BSplineSE3(SplineSE3): """A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline. The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline @@ -39,9 +261,9 @@ def __init__( - degree: int that controls degree of the polynomial that governs any given point on the spline. - knots: list of floats that govern which control points are active during evaluating the spline at a given t input. If none, they are automatically, uniformly generated based on number of control poses and - degree of spline. + degree of spline on the range [0,1]. """ - + super().__init__() self.control_poses = control_poses # a matrix where each row is a control pose as a twist @@ -74,32 +296,3 @@ def __call__(self, t: float) -> SE3: """ twist = np.hstack([spline(t) for spline in self.splines]) return SE3.Exp(twist) - - def visualize( - self, - num_samples: int, - length: float = 1.0, - repeat: bool = False, - ax: Optional[plt.Axes] = None, - kwargs_trplot: Dict[str, Any] = {"color": "green"}, - kwargs_tranimate: Dict[str, Any] = {"wait": True}, - kwargs_plot: Dict[str, Any] = {}, - ) -> None: - """Displays an animation of the trajectory with the control poses.""" - out_poses = [self(t) for t in np.linspace(0, 1, num_samples)] - x = [pose.x for pose in out_poses] - y = [pose.y for pose in out_poses] - z = [pose.z for pose in out_poses] - - if ax is None: - fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(projection="3d") - - trplot( - [np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot - ) # plot control points - ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory - - tranimate( - out_poses, repeat=repeat, length=length, **kwargs_tranimate - ) # animate pose along trajectory diff --git a/tests/test_spline.py b/tests/test_spline.py index f518fcfb..361bc28f 100644 --- a/tests/test_spline.py +++ b/tests/test_spline.py @@ -2,10 +2,8 @@ import numpy as np import matplotlib.pyplot as plt import unittest -import sys -import pytest -from spatialmath import BSplineSE3, SE3 +from spatialmath import BSplineSE3, SE3, InterpSplineSE3, SplineFit, SO3 class TestBSplineSE3(unittest.TestCase): @@ -29,4 +27,69 @@ def test_evaluation(self): def test_visualize(self): spline = BSplineSE3(self.control_poses) - spline.visualize(num_samples=100, repeat=False) + spline.visualize(sample_times= np.linspace(0, 1.0, 100), animate=True, repeat=False) + +class TestInterpSplineSE3: + waypoints = [ + SE3.Trans([e, 2 * np.cos(e / 2 * np.pi), 2 * np.sin(e / 2 * np.pi)]) + * SE3.Ry(e / 8 * np.pi) + for e in range(0, 8) + ] + time_horizon = 10 + times = np.linspace(0, time_horizon, len(waypoints)) + + @classmethod + def tearDownClass(cls): + plt.close("all") + + def test_constructor(self): + InterpSplineSE3(self.times, self.waypoints) + + def test_evaluation(self): + spline = InterpSplineSE3(self.times, self.waypoints) + for time, pose in zip(self.times, self.waypoints): + nt.assert_almost_equal(spline(time).angdist(pose), 0.0) + nt.assert_almost_equal(np.linalg.norm(spline(time).t - pose.t), 0.0) + + spline = InterpSplineSE3(self.times, self.waypoints, normalize_time=True) + norm_time = spline.timepoints + for time, pose in zip(norm_time, self.waypoints): + nt.assert_almost_equal(spline(time).angdist(pose), 0.0) + nt.assert_almost_equal(np.linalg.norm(spline(time).t - pose.t), 0.0) + + def test_small_delta_t(self): + InterpSplineSE3(np.linspace(0, InterpSplineSE3._e, len(self.waypoints)), self.waypoints) + + def test_visualize(self): + spline = InterpSplineSE3(self.times, self.waypoints) + spline.visualize(sample_times= np.linspace(0, self.time_horizon, 100), animate=True, repeat=False) + + +class TestSplineFit: + num_data_points = 300 + time_horizon = 5 + num_viz_points = 100 + + # make a helix + timestamps = np.linspace(0, 1, num_data_points) + trajectory = [ + SE3.Rt( + t=[t * 0.4, 0.4 * np.sin(t * 2 * np.pi * 0.5), 0.4 * np.cos(t * 2 * np.pi * 0.5)], + R=SO3.Rx(t * 2 * np.pi * 0.5), + ) + for t in timestamps * time_horizon + ] + + def test_spline_fit(self): + fit = SplineFit(self.timestamps, self.trajectory) + spline, kept_indices = fit.stochastic_downsample_interpolation() + + fraction_points_removed = 1.0 - len(kept_indices) / self.num_data_points + + assert(fraction_points_removed > 0.2) + assert(len(spline.control_poses)==len(kept_indices)) + assert(len(spline.timepoints)==len(kept_indices)) + + assert( fit.max_angular_error() < np.deg2rad(5.0) ) + assert( fit.max_angular_error() < 0.1 ) + spline.visualize(sample_times= np.linspace(0, self.time_horizon, 100), animate=True, repeat=False) \ No newline at end of file