Skip to content

Commit a6d6641

Browse files
Interpolating spline (#141)
1 parent 0c57dc0 commit a6d6641

File tree

5 files changed

+325
-52
lines changed

5 files changed

+325
-52
lines changed

.pre-commit-config.yaml

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
repos:
2-
# - repo: https://github.com/charliermarsh/ruff-pre-commit
3-
# # Ruff version.
4-
# rev: 'v0.1.0'
5-
# hooks:
6-
# - id: ruff
7-
# args: ['--fix', '--config', 'pyproject.toml']
2+
- repo: https://github.com/charliermarsh/ruff-pre-commit
3+
# Ruff version.
4+
rev: 'v0.1.0'
5+
hooks:
6+
- id: ruff
7+
args: ['--fix', '--config', 'pyproject.toml']
88

99
- repo: https://github.com/psf/black
1010
rev: 23.10.0
@@ -14,6 +14,21 @@ repos:
1414
args: ['--config', 'pyproject.toml']
1515
verbose: true
1616

17+
- repo: https://github.com/pre-commit/pre-commit-hooks
18+
rev: v4.5.0
19+
hooks:
20+
- id: end-of-file-fixer
21+
- id: debug-statements # Ensure we don't commit `import pdb; pdb.set_trace()`
22+
exclude: |
23+
(?x)^(
24+
docker/ros/web/static/.*|
25+
)$
26+
- id: trailing-whitespace
27+
exclude: |
28+
(?x)^(
29+
docker/ros/web/static/.*|
30+
(.*/).*\.patch|
31+
)$
1732
# - repo: https://github.com/pre-commit/mirrors-mypy
1833
# rev: v1.6.1
1934
# hooks:

spatialmath/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from spatialmath.quaternion import Quaternion, UnitQuaternion
1717
from spatialmath.DualQuaternion import DualQuaternion, UnitDualQuaternion
18-
from spatialmath.spline import BSplineSE3
18+
from spatialmath.spline import BSplineSE3, InterpSplineSE3, SplineFit
1919

2020
# from spatialmath.Plucker import *
2121
# from spatialmath import base as smb
@@ -45,6 +45,8 @@
4545
"Polygon2",
4646
"Ellipse",
4747
"BSplineSE3",
48+
"InterpSplineSE3",
49+
"SplineFit"
4850
]
4951

5052
try:

spatialmath/base/animate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def update(frame, animation):
217217
if isinstance(frame, float):
218218
# passed a single transform, interpolate it
219219
T = smb.trinterp(start=self.start, end=self.end, s=frame)
220-
elif isinstance(frame, NDArray):
220+
elif isinstance(frame, np.ndarray):
221221
# type is SO3Array or SE3Array when Animate.trajectory is not None
222222
T = frame
223223
else:

spatialmath/spline.py

Lines changed: 233 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,242 @@
22
# MIT Licence, see details in top-level file: LICENCE
33

44
"""
5-
Classes for parameterizing a trajectory in SE3 with B-splines.
6-
7-
Copies parts of the API from scipy's B-spline class.
5+
Classes for parameterizing a trajectory in SE3 with splines.
86
"""
97

10-
from typing import Any, Dict, List, Optional
11-
from scipy.interpolate import BSpline
12-
from spatialmath import SE3
13-
import numpy as np
8+
from abc import ABC, abstractmethod
9+
from functools import cached_property
10+
from typing import List, Optional, Tuple, Set
11+
1412
import matplotlib.pyplot as plt
15-
from spatialmath.base.transforms3d import tranimate, trplot
13+
import numpy as np
14+
from scipy.interpolate import BSpline, CubicSpline
15+
from scipy.spatial.transform import Rotation, RotationSpline
16+
17+
from spatialmath import SE3, SO3, Twist3
18+
from spatialmath.base.transforms3d import tranimate
19+
20+
21+
class SplineSE3(ABC):
22+
def __init__(self) -> None:
23+
self.control_poses: SE3
24+
25+
@abstractmethod
26+
def __call__(self, t: float) -> SE3:
27+
pass
28+
29+
def visualize(
30+
self,
31+
sample_times: List[float],
32+
input_trajectory: Optional[List[SE3]] = None,
33+
pose_marker_length: float = 0.2,
34+
animate: bool = False,
35+
repeat: bool = True,
36+
ax: Optional[plt.Axes] = None,
37+
) -> None:
38+
"""Displays an animation of the trajectory with the control poses against an optional input trajectory.
39+
40+
Args:
41+
sample_times: which times to sample the spline at and plot
42+
"""
43+
if ax is None:
44+
fig = plt.figure(figsize=(10, 10))
45+
ax = fig.add_subplot(projection="3d")
46+
47+
samples = [self(t) for t in sample_times]
48+
if not animate:
49+
pos = np.array([pose.t for pose in samples])
50+
ax.plot(
51+
pos[:, 0], pos[:, 1], pos[:, 2], "c", linewidth=1.0
52+
) # plot spline fit
53+
54+
pos = np.array([pose.t for pose in self.control_poses])
55+
ax.plot(pos[:, 0], pos[:, 1], pos[:, 2], "r*") # plot control_poses
56+
57+
if input_trajectory is not None:
58+
pos = np.array([pose.t for pose in input_trajectory])
59+
ax.plot(
60+
pos[:, 0], pos[:, 1], pos[:, 2], "go", fillstyle="none"
61+
) # plot compare to input poses
62+
63+
if animate:
64+
tranimate(
65+
samples, length=pose_marker_length, wait=True, repeat=repeat
66+
) # animate pose along trajectory
67+
else:
68+
plt.show()
69+
70+
71+
class InterpSplineSE3(SplineSE3):
72+
"""Class for an interpolated trajectory in SE3, as a function of time, through control_poses with a cubic spline.
73+
74+
A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic)
75+
under the hood.
76+
"""
77+
78+
_e = 1e-12
79+
80+
def __init__(
81+
self,
82+
timepoints: List[float],
83+
control_poses: List[SE3],
84+
*,
85+
normalize_time: bool = False,
86+
bc_type: str = "not-a-knot", # not-a-knot is scipy default; None is invalid
87+
) -> None:
88+
"""Construct a InterpSplineSE3 object
89+
90+
Extends the scipy CubicSpline object
91+
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline
92+
93+
Args :
94+
timepoints : list of times corresponding to provided poses
95+
control_poses : list of SE3 objects that govern the shape of the spline.
96+
normalize_time : flag to map times into the range [0, 1]
97+
bc_type : boundary condition provided to scipy CubicSpline backend.
98+
string options: ["not-a-knot" (default), "clamped", "natural", "periodic"].
99+
For tuple options and details see the scipy docs link above.
100+
"""
101+
super().__init__()
102+
self.control_poses = control_poses
103+
self.timepoints = np.array(timepoints)
104+
105+
if self.timepoints[-1] < self._e:
106+
raise ValueError(
107+
"Difference between start and end timepoints is less than {self._e}"
108+
)
109+
110+
if len(self.control_poses) != len(self.timepoints):
111+
raise ValueError("Length of control_poses and timepoints must be equal.")
112+
113+
if len(self.timepoints) < 2:
114+
raise ValueError("Need at least 2 data points to make a trajectory.")
115+
116+
if normalize_time:
117+
self.timepoints = self.timepoints - self.timepoints[0]
118+
self.timepoints = self.timepoints / self.timepoints[-1]
119+
120+
self.spline_xyz = CubicSpline(
121+
self.timepoints,
122+
np.array([pose.t for pose in self.control_poses]),
123+
bc_type=bc_type,
124+
)
125+
self.spline_so3 = RotationSpline(
126+
self.timepoints,
127+
Rotation.from_matrix(np.array([(pose.R) for pose in self.control_poses])),
128+
)
129+
130+
def __call__(self, t: float) -> SE3:
131+
"""Compute function value at t.
132+
Return:
133+
pose: SE3
134+
"""
135+
return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix())
136+
137+
def derivative(self, t: float) -> Twist3:
138+
linear_vel = self.spline_xyz.derivative()(t)
139+
angular_vel = self.spline_so3(
140+
t, 1
141+
) # 1 is angular rate, 2 is angular acceleration
142+
return Twist3(linear_vel, angular_vel)
143+
144+
145+
class SplineFit:
146+
"""A general class to fit various SE3 splines to data."""
147+
148+
def __init__(
149+
self,
150+
time_data: List[float],
151+
pose_data: List[SE3],
152+
) -> None:
153+
self.time_data = time_data
154+
self.pose_data = pose_data
155+
self.spline: Optional[SplineSE3] = None
156+
157+
def stochastic_downsample_interpolation(
158+
self,
159+
epsilon_xyz: float = 1e-3,
160+
epsilon_angle: float = 1e-1,
161+
normalize_time: bool = True,
162+
bc_type: str = "not-a-knot",
163+
check_type: str = "local"
164+
) -> Tuple[InterpSplineSE3, List[int]]:
165+
"""
166+
Uses a random dropout to downsample a trajectory with an interpolated spline. Keeps the start and
167+
end points of the trajectory. Takes a random order of the remaining indices, and then checks the error bound
168+
of just that point if check_type=="local", checks the error of the whole trajectory is check_type=="global".
169+
Local is **much** faster.
170+
171+
Return:
172+
downsampled interpolating spline,
173+
list of removed indices from input data
174+
"""
175+
176+
interpolation_indices = list(range(len(self.pose_data)))
177+
178+
# randomly attempt to remove poses from the trajectory
179+
# always keep the start and end
180+
removal_choices = interpolation_indices.copy()
181+
removal_choices.remove(0)
182+
removal_choices.remove(len(self.pose_data) - 1)
183+
np.random.shuffle(removal_choices)
184+
for candidate_removal_index in removal_choices:
185+
interpolation_indices.remove(candidate_removal_index)
186+
187+
self.spline = InterpSplineSE3(
188+
[self.time_data[i] for i in interpolation_indices],
189+
[self.pose_data[i] for i in interpolation_indices],
190+
normalize_time=normalize_time,
191+
bc_type=bc_type,
192+
)
193+
194+
sample_time = self.time_data[candidate_removal_index]
195+
if check_type is "local":
196+
angular_error = SO3(self.pose_data[candidate_removal_index]).angdist(
197+
SO3(self.spline.spline_so3(sample_time).as_matrix())
198+
)
199+
euclidean_error = np.linalg.norm(
200+
self.pose_data[candidate_removal_index].t - self.spline.spline_xyz(sample_time)
201+
)
202+
elif check_type is "global":
203+
angular_error = self.max_angular_error()
204+
euclidean_error = self.max_euclidean_error()
205+
else:
206+
raise ValueError(f"check_type must be 'local' of 'global', is {check_type}.")
207+
208+
if (angular_error > epsilon_angle) or (euclidean_error > epsilon_xyz):
209+
interpolation_indices.append(candidate_removal_index)
210+
interpolation_indices.sort()
16211

212+
self.spline = InterpSplineSE3(
213+
[self.time_data[i] for i in interpolation_indices],
214+
[self.pose_data[i] for i in interpolation_indices],
215+
normalize_time=normalize_time,
216+
bc_type=bc_type,
217+
)
218+
219+
return self.spline, interpolation_indices
220+
221+
def max_angular_error(self) -> float:
222+
return np.max(self.angular_errors())
223+
224+
def angular_errors(self) -> List[float]:
225+
return [
226+
pose.angdist(self.spline(t))
227+
for pose, t in zip(self.pose_data, self.time_data)
228+
]
229+
230+
def max_euclidean_error(self) -> float:
231+
return np.max(self.euclidean_errors())
17232

18-
class BSplineSE3:
233+
def euclidean_errors(self) -> List[float]:
234+
return [
235+
np.linalg.norm(pose.t - self.spline(t).t)
236+
for pose, t in zip(self.pose_data, self.time_data)
237+
]
238+
239+
240+
class BSplineSE3(SplineSE3):
19241
"""A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.
20242
21243
The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline
@@ -39,9 +261,9 @@ def __init__(
39261
- degree: int that controls degree of the polynomial that governs any given point on the spline.
40262
- knots: list of floats that govern which control points are active during evaluating the spline
41263
at a given t input. If none, they are automatically, uniformly generated based on number of control poses and
42-
degree of spline.
264+
degree of spline on the range [0,1].
43265
"""
44-
266+
super().__init__()
45267
self.control_poses = control_poses
46268

47269
# a matrix where each row is a control pose as a twist
@@ -74,32 +296,3 @@ def __call__(self, t: float) -> SE3:
74296
"""
75297
twist = np.hstack([spline(t) for spline in self.splines])
76298
return SE3.Exp(twist)
77-
78-
def visualize(
79-
self,
80-
num_samples: int,
81-
length: float = 1.0,
82-
repeat: bool = False,
83-
ax: Optional[plt.Axes] = None,
84-
kwargs_trplot: Dict[str, Any] = {"color": "green"},
85-
kwargs_tranimate: Dict[str, Any] = {"wait": True},
86-
kwargs_plot: Dict[str, Any] = {},
87-
) -> None:
88-
"""Displays an animation of the trajectory with the control poses."""
89-
out_poses = [self(t) for t in np.linspace(0, 1, num_samples)]
90-
x = [pose.x for pose in out_poses]
91-
y = [pose.y for pose in out_poses]
92-
z = [pose.z for pose in out_poses]
93-
94-
if ax is None:
95-
fig = plt.figure(figsize=(10, 10))
96-
ax = fig.add_subplot(projection="3d")
97-
98-
trplot(
99-
[np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot
100-
) # plot control points
101-
ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory
102-
103-
tranimate(
104-
out_poses, repeat=repeat, length=length, **kwargs_tranimate
105-
) # animate pose along trajectory

0 commit comments

Comments
 (0)