Skip to content

Commit cce27d0

Browse files
author
David Butterworth
committed
Change GetLinearCollisionCheckPts() to directly take a sampling function
1 parent a1e375d commit cce27d0

File tree

2 files changed

+32
-46
lines changed

2 files changed

+32
-46
lines changed

src/prpy/util.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,7 +1191,7 @@ def GetCollisionCheckPts(robot, traj, include_start=True, start_time=0.,
11911191
dt = 2. * dt
11921192

11931193

1194-
def GetLinearCollisionCheckPts(robot, traj, norm_order=2, sampling_order=None):
1194+
def GetLinearCollisionCheckPts(robot, traj, norm_order=2, sampling_func=None):
11951195
"""
11961196
For a piece-wise linear trajectory, generate a list
11971197
of configuration pairs that need to be collision checked.
@@ -1205,9 +1205,11 @@ def GetLinearCollisionCheckPts(robot, traj, norm_order=2, sampling_order=None):
12051205
@param int norm_order: 1 ==> The L1 norm
12061206
2 ==> The L2 norm
12071207
inf ==> The L_infinity norm
1208-
@param string sampling_order:
1209-
'linear' : Sample in a linear sequence
1210-
'van_der_corput' : Sample in an optimal sequence
1208+
@param generator sampling_func A function that returns a sequence of
1209+
sample times.
1210+
e.g. SampleTimeGenerator()
1211+
or
1212+
VanDerCorputSampleGenerator()
12111213
12121214
@returns generator: A tuple (t,q) of float values, being the sample
12131215
time and joint configuration.
@@ -1271,7 +1273,7 @@ def GetLinearCollisionCheckPts(robot, traj, norm_order=2, sampling_order=None):
12711273
q1 = traj_cspec.ExtractJointValues(waypoint, robot, dof_indices)
12721274
dq = numpy.abs(q1 - q0)
12731275
max_diff_float = numpy.max( numpy.abs(q1 - q0) / q_resolutions)
1274-
1276+
12751277
# Get the number of steps (as a float) required for
12761278
# each joint at DOF resolution
12771279
num_steps = dq / q_resolutions
@@ -1302,20 +1304,13 @@ def GetLinearCollisionCheckPts(robot, traj, norm_order=2, sampling_order=None):
13021304

13031305
traj_duration = temp_traj.GetDuration()
13041306

1305-
# Sample the trajectory using the specified sequence
1306-
if sampling_order == None:
1307-
sampling_order = 'linear'
1307+
# Sample the trajectory using the specified sample generator
13081308
seq = None
1309-
if sampling_order == 'van_der_corput':
1310-
# An approximate Van der Corput sequence, between the
1311-
# start and end points
1312-
seq = VanDerCorputSampleGenerator(0, traj_duration, step=2)
1313-
elif sampling_order == 'linear':
1309+
if sampling_func == None:
13141310
# (default) Linear sequence, from start to end
13151311
seq = SampleTimeGenerator(0, traj_duration, step=2)
13161312
else:
1317-
error = "Unknown sampling_order '" + sampling_order + "' "
1318-
raise ValueError(error)
1313+
seq = sampling_func(0, traj_duration, step=2)
13191314

13201315
# Sample the trajectory in time
13211316
# and return time value and joint positions

tests/test_util.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,12 @@ def test_GetLinearCollisionCheckPts_SinglePointTraj(self):
250250
q0 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
251251
desired_num_checks = 1
252252
traj = self.CreateTrajectory(q0, q0) # makes traj with 1 waypoint
253+
# Linear sampling
254+
linear = prpy.util.SampleTimeGenerator
253255
checks = prpy.util.GetLinearCollisionCheckPts(self.robot, \
254256
traj, \
255257
norm_order=2, \
256-
sampling_order='linear')
258+
sampling_func=linear)
257259
num_checks = sum(1 for x in checks)
258260
if num_checks != desired_num_checks:
259261
error = str(num_checks) + ' is the wrong number of check pts.'
@@ -270,10 +272,12 @@ def test_GetLinearCollisionCheckPts_TwoPointTraj_LessThanDOFRes(self):
270272
q1 = 0.5 * self.dof_resolutions
271273
desired_num_checks = 2
272274
traj = self.CreateTrajectory(q0, q1)
275+
# Linear sampling
276+
linear = prpy.util.SampleTimeGenerator
273277
checks = prpy.util.GetLinearCollisionCheckPts(self.robot, \
274278
traj, \
275279
norm_order=2, \
276-
sampling_order='linear')
280+
sampling_func=linear)
277281
num_checks = 0
278282
for t, q in checks:
279283
num_checks = num_checks + 1
@@ -301,10 +305,12 @@ def test_GetLinearCollisionCheckPts_TwoPointTraj_EqualToDOFRes(self):
301305
q1 = 1.0 * self.dof_resolutions
302306
desired_num_checks = 2
303307
traj = self.CreateTrajectory(q0, q1)
308+
# Linear sampling
309+
linear = prpy.util.SampleTimeGenerator
304310
checks = prpy.util.GetLinearCollisionCheckPts(self.robot, \
305311
traj, \
306312
norm_order=2, \
307-
sampling_order='linear')
313+
sampling_func=linear)
308314
num_checks = 0
309315
for t, q in checks:
310316
num_checks = num_checks + 1
@@ -325,10 +331,12 @@ def test_GetLinearCollisionCheckPts_TwoPointTraj_GreaterThanDOFRes(self):
325331
q1 = 1.2 * self.dof_resolutions
326332
desired_num_checks = 3
327333
traj = self.CreateTrajectory(q0, q1)
334+
# Linear sampling
335+
linear = prpy.util.SampleTimeGenerator
328336
checks = prpy.util.GetLinearCollisionCheckPts(self.robot, \
329337
traj, \
330338
norm_order=2, \
331-
sampling_order='linear')
339+
sampling_func=linear)
332340
num_checks = 0
333341
for t, q in checks:
334342
num_checks = num_checks + 1
@@ -349,60 +357,43 @@ def test_GetLinearCollisionCheckPts_SamplingOrderMethods(self):
349357
q1 = [0.01, 0.02, 0.03, 0.04, 0.03, 0.02, 0.01]
350358
traj = self.CreateTrajectory(q0, q1)
351359

352-
# Linear
353-
method = 'linear'
360+
# Linear sampling
361+
linear = prpy.util.SampleTimeGenerator
354362
checks = prpy.util.GetLinearCollisionCheckPts(self.robot, \
355363
traj, \
356364
norm_order=2, \
357-
sampling_order=method)
365+
sampling_func=linear)
358366
try:
359367
# Exception can be thrown only when we try to use the generator
360368
num_checks = sum(1 for x in checks)
361369
except ValueError:
362-
error = "Unknown sampling_order method: '" + method + "' "
370+
error = "Unknown sampling_func: '" + method + "' "
363371
self.fail(error)
364372
except Exception, e:
365373
error = 'Unexpected exception thrown: ' + str(e.message)
366374
self.fail(error)
367375
else:
368376
pass # test passed
369377

370-
# Van der Corput
371-
method = 'van_der_corput'
378+
# An approximate Van der Corput sequence, between the
379+
# start and end points
380+
vdc = prpy.util.VanDerCorputSampleGenerator
372381
checks = prpy.util.GetLinearCollisionCheckPts(self.robot, \
373382
traj, \
374383
norm_order=2, \
375-
sampling_order=method)
384+
sampling_func=vdc)
376385
try:
377386
# Exception can be thrown only when we try to use the generator
378387
num_checks = sum(1 for x in checks)
379388
except ValueError:
380-
error = "Unknown sampling_order method: '" + method + "' "
389+
error = "Unknown sampling_func: '" + method + "' "
381390
self.fail(error)
382391
except Exception, e:
383392
error = 'Unexpected exception thrown: ' + str(e.message)
384393
self.fail(error)
385394
else:
386395
pass # test passed
387396

388-
# unknown method
389-
method = 'garbage'
390-
checks = prpy.util.GetLinearCollisionCheckPts(self.robot, \
391-
traj, \
392-
norm_order=2, \
393-
sampling_order=method)
394-
try:
395-
# Exception can be thrown only when we try to use the generator
396-
num_checks = sum(1 for x in checks)
397-
except ValueError:
398-
pass # test passed, an exception was thrown
399-
400-
except Exception, e:
401-
error = 'Unexpected exception thrown: ' + str(e.message)
402-
self.fail(error)
403-
else:
404-
self.fail('Expected exception not thrown')
405-
406397

407398
# ConvertIntToBinaryString()
408399

@@ -501,7 +492,7 @@ def test_VanDerCorputSampleGenerator_ExcessLessThanHalfStep(self):
501492
verbose=True)
502493

503494
def test_VanDerCorputSampleGenerator_ExcessEqualToHalfStep(self):
504-
# Check that the end-point is NOT included when it's distance
495+
# Check that the end-point is NOT included when it's distance
505496
# is EQUAL too (or less than) the step-size.
506497
expected_sequence = [0.0, 12.0, 6.0, 4.0, 8.0, 2.0, 10.0]
507498
traj_dur = 13.0

0 commit comments

Comments
 (0)