2
2
# MIT Licence, see details in top-level file: LICENCE
3
3
4
4
"""
5
- Classes for parameterizing a trajectory in SE3 with B-splines.
6
-
7
- Copies parts of the API from scipy's B-spline class.
5
+ Classes for parameterizing a trajectory in SE3 with splines.
8
6
"""
9
7
10
- from typing import Any , Dict , List , Optional
11
- from scipy . interpolate import BSpline
12
- from spatialmath import SE3
13
- import numpy as np
8
+ from abc import ABC , abstractmethod
9
+ from functools import cached_property
10
+ from typing import List , Optional , Tuple , Set
11
+
14
12
import matplotlib .pyplot as plt
15
- from spatialmath .base .transforms3d import tranimate , trplot
13
+ import numpy as np
14
+ from scipy .interpolate import BSpline , CubicSpline
15
+ from scipy .spatial .transform import Rotation , RotationSpline
16
+
17
+ from spatialmath import SE3 , SO3 , Twist3
18
+ from spatialmath .base .transforms3d import tranimate
19
+
20
+
21
+ class SplineSE3 (ABC ):
22
+ def __init__ (self ) -> None :
23
+ self .control_poses : SE3
24
+
25
+ @abstractmethod
26
+ def __call__ (self , t : float ) -> SE3 :
27
+ pass
28
+
29
+ def visualize (
30
+ self ,
31
+ sample_times : List [float ],
32
+ input_trajectory : Optional [List [SE3 ]] = None ,
33
+ pose_marker_length : float = 0.2 ,
34
+ animate : bool = False ,
35
+ repeat : bool = True ,
36
+ ax : Optional [plt .Axes ] = None ,
37
+ ) -> None :
38
+ """Displays an animation of the trajectory with the control poses against an optional input trajectory.
39
+
40
+ Args:
41
+ sample_times: which times to sample the spline at and plot
42
+ """
43
+ if ax is None :
44
+ fig = plt .figure (figsize = (10 , 10 ))
45
+ ax = fig .add_subplot (projection = "3d" )
46
+
47
+ samples = [self (t ) for t in sample_times ]
48
+ if not animate :
49
+ pos = np .array ([pose .t for pose in samples ])
50
+ ax .plot (
51
+ pos [:, 0 ], pos [:, 1 ], pos [:, 2 ], "c" , linewidth = 1.0
52
+ ) # plot spline fit
53
+
54
+ pos = np .array ([pose .t for pose in self .control_poses ])
55
+ ax .plot (pos [:, 0 ], pos [:, 1 ], pos [:, 2 ], "r*" ) # plot control_poses
56
+
57
+ if input_trajectory is not None :
58
+ pos = np .array ([pose .t for pose in input_trajectory ])
59
+ ax .plot (
60
+ pos [:, 0 ], pos [:, 1 ], pos [:, 2 ], "go" , fillstyle = "none"
61
+ ) # plot compare to input poses
62
+
63
+ if animate :
64
+ tranimate (
65
+ samples , length = pose_marker_length , wait = True , repeat = repeat
66
+ ) # animate pose along trajectory
67
+ else :
68
+ plt .show ()
69
+
70
+
71
+ class InterpSplineSE3 (SplineSE3 ):
72
+ """Class for an interpolated trajectory in SE3, as a function of time, through control_poses with a cubic spline.
73
+
74
+ A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic)
75
+ under the hood.
76
+ """
77
+
78
+ _e = 1e-12
79
+
80
+ def __init__ (
81
+ self ,
82
+ timepoints : List [float ],
83
+ control_poses : List [SE3 ],
84
+ * ,
85
+ normalize_time : bool = False ,
86
+ bc_type : str = "not-a-knot" , # not-a-knot is scipy default; None is invalid
87
+ ) -> None :
88
+ """Construct a InterpSplineSE3 object
89
+
90
+ Extends the scipy CubicSpline object
91
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline
92
+
93
+ Args :
94
+ timepoints : list of times corresponding to provided poses
95
+ control_poses : list of SE3 objects that govern the shape of the spline.
96
+ normalize_time : flag to map times into the range [0, 1]
97
+ bc_type : boundary condition provided to scipy CubicSpline backend.
98
+ string options: ["not-a-knot" (default), "clamped", "natural", "periodic"].
99
+ For tuple options and details see the scipy docs link above.
100
+ """
101
+ super ().__init__ ()
102
+ self .control_poses = control_poses
103
+ self .timepoints = np .array (timepoints )
104
+
105
+ if self .timepoints [- 1 ] < self ._e :
106
+ raise ValueError (
107
+ "Difference between start and end timepoints is less than {self._e}"
108
+ )
109
+
110
+ if len (self .control_poses ) != len (self .timepoints ):
111
+ raise ValueError ("Length of control_poses and timepoints must be equal." )
112
+
113
+ if len (self .timepoints ) < 2 :
114
+ raise ValueError ("Need at least 2 data points to make a trajectory." )
115
+
116
+ if normalize_time :
117
+ self .timepoints = self .timepoints - self .timepoints [0 ]
118
+ self .timepoints = self .timepoints / self .timepoints [- 1 ]
119
+
120
+ self .spline_xyz = CubicSpline (
121
+ self .timepoints ,
122
+ np .array ([pose .t for pose in self .control_poses ]),
123
+ bc_type = bc_type ,
124
+ )
125
+ self .spline_so3 = RotationSpline (
126
+ self .timepoints ,
127
+ Rotation .from_matrix (np .array ([(pose .R ) for pose in self .control_poses ])),
128
+ )
129
+
130
+ def __call__ (self , t : float ) -> SE3 :
131
+ """Compute function value at t.
132
+ Return:
133
+ pose: SE3
134
+ """
135
+ return SE3 .Rt (t = self .spline_xyz (t ), R = self .spline_so3 (t ).as_matrix ())
136
+
137
+ def derivative (self , t : float ) -> Twist3 :
138
+ linear_vel = self .spline_xyz .derivative ()(t )
139
+ angular_vel = self .spline_so3 (
140
+ t , 1
141
+ ) # 1 is angular rate, 2 is angular acceleration
142
+ return Twist3 (linear_vel , angular_vel )
143
+
144
+
145
+ class SplineFit :
146
+ """A general class to fit various SE3 splines to data."""
147
+
148
+ def __init__ (
149
+ self ,
150
+ time_data : List [float ],
151
+ pose_data : List [SE3 ],
152
+ ) -> None :
153
+ self .time_data = time_data
154
+ self .pose_data = pose_data
155
+ self .spline : Optional [SplineSE3 ] = None
156
+
157
+ def stochastic_downsample_interpolation (
158
+ self ,
159
+ epsilon_xyz : float = 1e-3 ,
160
+ epsilon_angle : float = 1e-1 ,
161
+ normalize_time : bool = True ,
162
+ bc_type : str = "not-a-knot" ,
163
+ check_type : str = "local"
164
+ ) -> Tuple [InterpSplineSE3 , List [int ]]:
165
+ """
166
+ Uses a random dropout to downsample a trajectory with an interpolated spline. Keeps the start and
167
+ end points of the trajectory. Takes a random order of the remaining indices, and then checks the error bound
168
+ of just that point if check_type=="local", checks the error of the whole trajectory is check_type=="global".
169
+ Local is **much** faster.
170
+
171
+ Return:
172
+ downsampled interpolating spline,
173
+ list of removed indices from input data
174
+ """
175
+
176
+ interpolation_indices = list (range (len (self .pose_data )))
177
+
178
+ # randomly attempt to remove poses from the trajectory
179
+ # always keep the start and end
180
+ removal_choices = interpolation_indices .copy ()
181
+ removal_choices .remove (0 )
182
+ removal_choices .remove (len (self .pose_data ) - 1 )
183
+ np .random .shuffle (removal_choices )
184
+ for candidate_removal_index in removal_choices :
185
+ interpolation_indices .remove (candidate_removal_index )
186
+
187
+ self .spline = InterpSplineSE3 (
188
+ [self .time_data [i ] for i in interpolation_indices ],
189
+ [self .pose_data [i ] for i in interpolation_indices ],
190
+ normalize_time = normalize_time ,
191
+ bc_type = bc_type ,
192
+ )
193
+
194
+ sample_time = self .time_data [candidate_removal_index ]
195
+ if check_type is "local" :
196
+ angular_error = SO3 (self .pose_data [candidate_removal_index ]).angdist (
197
+ SO3 (self .spline .spline_so3 (sample_time ).as_matrix ())
198
+ )
199
+ euclidean_error = np .linalg .norm (
200
+ self .pose_data [candidate_removal_index ].t - self .spline .spline_xyz (sample_time )
201
+ )
202
+ elif check_type is "global" :
203
+ angular_error = self .max_angular_error ()
204
+ euclidean_error = self .max_euclidean_error ()
205
+ else :
206
+ raise ValueError (f"check_type must be 'local' of 'global', is { check_type } ." )
207
+
208
+ if (angular_error > epsilon_angle ) or (euclidean_error > epsilon_xyz ):
209
+ interpolation_indices .append (candidate_removal_index )
210
+ interpolation_indices .sort ()
16
211
212
+ self .spline = InterpSplineSE3 (
213
+ [self .time_data [i ] for i in interpolation_indices ],
214
+ [self .pose_data [i ] for i in interpolation_indices ],
215
+ normalize_time = normalize_time ,
216
+ bc_type = bc_type ,
217
+ )
218
+
219
+ return self .spline , interpolation_indices
220
+
221
+ def max_angular_error (self ) -> float :
222
+ return np .max (self .angular_errors ())
223
+
224
+ def angular_errors (self ) -> List [float ]:
225
+ return [
226
+ pose .angdist (self .spline (t ))
227
+ for pose , t in zip (self .pose_data , self .time_data )
228
+ ]
229
+
230
+ def max_euclidean_error (self ) -> float :
231
+ return np .max (self .euclidean_errors ())
17
232
18
- class BSplineSE3 :
233
+ def euclidean_errors (self ) -> List [float ]:
234
+ return [
235
+ np .linalg .norm (pose .t - self .spline (t ).t )
236
+ for pose , t in zip (self .pose_data , self .time_data )
237
+ ]
238
+
239
+
240
+ class BSplineSE3 (SplineSE3 ):
19
241
"""A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.
20
242
21
243
The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline
@@ -39,9 +261,9 @@ def __init__(
39
261
- degree: int that controls degree of the polynomial that governs any given point on the spline.
40
262
- knots: list of floats that govern which control points are active during evaluating the spline
41
263
at a given t input. If none, they are automatically, uniformly generated based on number of control poses and
42
- degree of spline.
264
+ degree of spline on the range [0,1] .
43
265
"""
44
-
266
+ super (). __init__ ()
45
267
self .control_poses = control_poses
46
268
47
269
# a matrix where each row is a control pose as a twist
@@ -74,32 +296,3 @@ def __call__(self, t: float) -> SE3:
74
296
"""
75
297
twist = np .hstack ([spline (t ) for spline in self .splines ])
76
298
return SE3 .Exp (twist )
77
-
78
- def visualize (
79
- self ,
80
- num_samples : int ,
81
- length : float = 1.0 ,
82
- repeat : bool = False ,
83
- ax : Optional [plt .Axes ] = None ,
84
- kwargs_trplot : Dict [str , Any ] = {"color" : "green" },
85
- kwargs_tranimate : Dict [str , Any ] = {"wait" : True },
86
- kwargs_plot : Dict [str , Any ] = {},
87
- ) -> None :
88
- """Displays an animation of the trajectory with the control poses."""
89
- out_poses = [self (t ) for t in np .linspace (0 , 1 , num_samples )]
90
- x = [pose .x for pose in out_poses ]
91
- y = [pose .y for pose in out_poses ]
92
- z = [pose .z for pose in out_poses ]
93
-
94
- if ax is None :
95
- fig = plt .figure (figsize = (10 , 10 ))
96
- ax = fig .add_subplot (projection = "3d" )
97
-
98
- trplot (
99
- [np .array (self .control_poses )], ax = ax , length = length , ** kwargs_trplot
100
- ) # plot control points
101
- ax .plot (x , y , z , ** kwargs_plot ) # plot x,y,z trajectory
102
-
103
- tranimate (
104
- out_poses , repeat = repeat , length = length , ** kwargs_tranimate
105
- ) # animate pose along trajectory
0 commit comments