Skip to content

Commit 449ead2

Browse files
committed
added ICP
more doco
1 parent 3eeb080 commit 449ead2

File tree

1 file changed

+163
-9
lines changed

1 file changed

+163
-9
lines changed

spatialmath/base/transforms2d.py

Lines changed: 163 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def xyt2tr(xyt, unit="rad"):
106106
:type xyt: array_like(3)
107107
:param unit: angular units: 'rad' [default], or 'deg'
108108
:type unit: str
109-
:return: 3x3 homogeneous transformation matrix
109+
:return: SE(2) matrix
110110
:rtype: ndarray(3,3)
111111
112112
- ``xyt2tr([x,y,θ])`` is a homogeneous transformation (3x3) representing a rotation of
@@ -170,8 +170,8 @@ def transl2(x, y=None):
170170
:type x: float
171171
:param y: translation along Y-axis
172172
:type y: float
173-
:return: SE(2) transform matrix or the translation elements of a homogeneous
174-
transform :rtype: ndarray(3,3)
173+
:return: SE(2) matrix
174+
:rtype: ndarray(3,3)
175175
176176
- ``T = transl2([X, Y])`` is an SE(2) homogeneous transform (3x3)
177177
representing a pure translation.
@@ -584,14 +584,14 @@ def trinterp2(start, end, s=None):
584584
:rtype: ndarray(3,3) or ndarray(2,2)
585585
:raises ValueError: bad arguments
586586
587-
- ``trinterp2(None, T, S)`` is a homogeneous transform (3x3) interpolated
588-
between identity when S=0 and T (3x3) when S=1.
587+
- ``trinterp2(None, T, S)`` is an SE(2) matrix interpolated
588+
between identity when `S`=0 and `T` when `S`=1.
589589
- ``trinterp2(T0, T1, S)`` as above but interpolated
590-
between T0 (3x3) when S=0 and T1 (3x3) when S=1.
591-
- ``trinterp2(None, R, S)`` is a rotation matrix (2x2) interpolated
592-
between identity when S=0 and R (2x2) when S=1.
590+
between `T0` when `S`=0 and `T1` when `S`=1.
591+
- ``trinterp2(None, R, S)`` is an SO(2) matrix interpolated
592+
between identity when `S`=0 and `R` when `S`=1.
593593
- ``trinterp2(R0, R1, S)`` as above but interpolated
594-
between R0 (2x2) when S=0 and R1 (2x2) when S=1.
594+
between `R0` when `S`=0 and `R1` when `S`=1.
595595
596596
.. note:: Rotation angle is linearly interpolated.
597597
@@ -777,6 +777,160 @@ def points2tr2(p1, p2):
777777

778778
return T
779779

780+
# https://github.com/ClayFlannigan/icp/blob/master/icp.py
781+
# https://github.com/1988kramer/intel_dataset/blob/master/scripts/Align2D.py
782+
# hack below to use points2tr above
783+
# use ClayFlannigan's improved data association
784+
from scipy.spatial import KDTree
785+
import numpy as np
786+
787+
# reference or target 2xN
788+
# source 2xN
789+
790+
# params:
791+
# source_points: numpy array containing points to align to the reference set
792+
# points should be homogeneous, with one point per row
793+
# reference_points: numpy array containing points to which the source points
794+
# are to be aligned, points should be homogeneous with one
795+
# point per row
796+
# initial_T: initial estimate of the transform between reference and source
797+
# def __init__(self, source_points, reference_points, initial_T):
798+
# self.source = source_points
799+
# self.reference = reference_points
800+
# self.init_T = initial_T
801+
# self.reference_tree = KDTree(reference_points[:,:2])
802+
# self.transform = self.AlignICP(30, 1.0e-4)
803+
804+
# uses the iterative closest point algorithm to find the
805+
# transformation between the source and reference point clouds
806+
# that minimizes the sum of squared errors between nearest
807+
# neighbors in the two point clouds
808+
# params:
809+
# max_iter: int, max number of iterations
810+
# min_delta_err: float, minimum change in alignment error
811+
def ICP2d(reference, source, T=None, max_iter=20, min_delta_err=1e-4):
812+
813+
mean_sq_error = 1.0e6 # initialize error as large number
814+
delta_err = 1.0e6 # change in error (used in stopping condition)
815+
num_iter = 0 # number of iterations
816+
if T is None:
817+
T = np.eye(3)
818+
819+
ref_kdtree = KDTree(reference.T)
820+
tf_source = source
821+
822+
source_hom = np.vstack((source, np.ones(source.shape[1])))
823+
824+
while delta_err > min_delta_err and num_iter < max_iter:
825+
826+
# find correspondences via nearest-neighbor search
827+
matched_ref_pts, matched_source, indices = _FindCorrespondences(ref_kdtree, tf_source, reference)
828+
829+
# find alingment between source and corresponding reference points via SVD
830+
# note: svd step doesn't use homogeneous points
831+
new_T = _AlignSVD(matched_source, matched_ref_pts)
832+
833+
# update transformation between point sets
834+
T = T @ new_T
835+
836+
# apply transformation to the source points
837+
tf_source = T @ source_hom
838+
tf_source = tf_source[:2, :]
839+
840+
# find mean squared error between transformed source points and reference points
841+
# TODO: do this with fancy indexing
842+
new_err = 0
843+
for i in range(len(indices)):
844+
if indices[i] != -1:
845+
diff = tf_source[:, i] - reference[:, indices[i]]
846+
new_err += np.dot(diff,diff.T)
847+
848+
new_err /= float(len(matched_ref_pts))
849+
850+
# update error and calculate delta error
851+
delta_err = abs(mean_sq_error - new_err)
852+
mean_sq_error = new_err
853+
print('ITER', num_iter, delta_err, mean_sq_error)
854+
855+
num_iter += 1
856+
857+
return T
858+
859+
860+
def _FindCorrespondences(tree, source, reference):
861+
862+
# get distances to nearest neighbors and indices of nearest neighbors
863+
dist, indices = tree.query(source.T)
864+
865+
# remove multiple associatons from index list
866+
# only retain closest associations
867+
unique = False
868+
matched_src = source.copy()
869+
while not unique:
870+
unique = True
871+
for i, idxi in enumerate(indices):
872+
if idxi == -1:
873+
continue
874+
# could do this with np.nonzero
875+
for j in range(i+1,len(indices)):
876+
if idxi == indices[j]:
877+
if dist[i] < dist[j]:
878+
indices[j] = -1
879+
else:
880+
indices[i] = -1
881+
break
882+
# build array of nearest neighbor reference points
883+
# and remove unmatched source points
884+
point_list = []
885+
src_idx = 0
886+
for idx in indices:
887+
if idx != -1:
888+
point_list.append(reference[:,idx])
889+
src_idx += 1
890+
else:
891+
matched_src = np.delete(matched_src, src_idx, axis=1)
892+
893+
matched_ref = np.array(point_list).T
894+
895+
return matched_ref, matched_src, indices
896+
897+
# uses singular value decomposition to find the
898+
# transformation from the reference to the source point cloud
899+
# assumes source and reference point clounds are ordered such that
900+
# corresponding points are at the same indices in each array
901+
#
902+
# params:
903+
# source: numpy array representing source pointcloud
904+
# reference: numpy array representing reference pointcloud
905+
# returns:
906+
# T: transformation between the two point clouds
907+
908+
# TODO: replace this func with
909+
def _AlignSVD(source, reference):
910+
911+
# first find the centroids of both point clouds
912+
src_centroid = source.mean(axis=1)
913+
ref_centroid = reference.mean(axis=1)
914+
915+
# get the point clouds in reference to their centroids
916+
source_centered = source - src_centroid[:, np.newaxis]
917+
reference_centered = reference - ref_centroid[:, np.newaxis]
918+
919+
# compute the moment matrix
920+
M = reference_centered @ source_centered.T
921+
922+
# do the singular value decomposition
923+
U, W, V_t = np.linalg.svd(M)
924+
925+
# get rotation between the two point clouds
926+
R = U @ V_t
927+
if np.linalg.det(R) < 0:
928+
raise RuntimeError('bad rotation matrix')
929+
930+
# translation is the difference between the point clound centroids
931+
t = ref_centroid - R @ src_centroid
932+
933+
return base.rt2tr(R, t)
780934

781935
def trplot2(
782936
T,

0 commit comments

Comments
 (0)