Skip to content

Commit 256232d

Browse files
authored
ensure that angdiff returns a scalar for scalar args (#89)
1 parent e3b5507 commit 256232d

File tree

7 files changed

+120
-9
lines changed

7 files changed

+120
-9
lines changed

spatialmath/base/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@
216216
"isrot2",
217217
"trlog2",
218218
"trexp2",
219+
"trnorm2",
219220
"tr2jac2",
220221
"trinterp2",
221222
"trprint2",

spatialmath/base/transforms2d.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import spatialmath.base as smb
2929
from spatialmath.base.types import *
3030
from spatialmath.base.transformsNd import rt2tr
31+
from spatialmath.base.vectors import unitvec
3132

3233
_eps = np.finfo(np.float64).eps
3334

@@ -679,6 +680,72 @@ def trexp2(
679680
raise ValueError(" First argument must be SO(2), 1-vector, SE(2) or 3-vector")
680681

681682

683+
@overload # pragma: no cover
684+
def trnorm2(R: SO2Array) -> SO2Array:
685+
...
686+
687+
688+
def trnorm2(T: SE2Array) -> SE2Array:
689+
r"""
690+
Normalize an SO(2) or SE(2) matrix
691+
692+
:param T: SE(2) or SO(2) matrix
693+
:type T: ndarray(3,3) or ndarray(2,2)
694+
:return: normalized SE(2) or SO(2) matrix
695+
:rtype: ndarray(3,3) or ndarray(2,2)
696+
:raises ValueError: bad arguments
697+
698+
- ``trnorm(R)`` is guaranteed to be a proper orthogonal matrix rotation
699+
matrix (2,2) which is *close* to the input matrix R (2,2).
700+
- ``trnorm(T)`` as above but the rotational submatrix of the homogeneous
701+
transformation T (3,3) is normalised while the translational part is
702+
unchanged.
703+
704+
The steps in normalization are:
705+
706+
#. If :math:`\mathbf{R} = [a, b]`
707+
#. Form unit vectors :math:`\hat{b}
708+
#. Form the orthogonal planar vector :math:`\hat{a} = [\hat{b}_y -\hat{b}_x]`
709+
#. Form the normalized SO(2) matrix :math:`\mathbf{R} = [\hat{a}, \hat{b}]`
710+
711+
.. runblock:: pycon
712+
713+
>>> from spatialmath.base import trnorm, troty
714+
>>> from numpy import linalg
715+
>>> T = trot2(45, 'deg', t=[3, 4])
716+
>>> linalg.det(T[:2,:2]) - 1 # is a valid SO(3)
717+
>>> T = T @ T @ T @ T @ T @ T @ T @ T @ T @ T @ T @ T @ T
718+
>>> linalg.det(T[:2,:2]) - 1 # not quite a valid SE(2) anymore
719+
>>> T = trnorm2(T)
720+
>>> linalg.det(T[:2,:2]) - 1 # once more a valid SE(2)
721+
722+
.. note::
723+
724+
- Only the direction of a-vector (the z-axis) is unchanged.
725+
- Used to prevent finite word length arithmetic causing transforms to
726+
become 'unnormalized', ie. determinant :math:`\ne 1`.
727+
"""
728+
729+
if not ishom2(T) and not isrot2(T):
730+
raise ValueError("expecting SO(2) or SE(2)")
731+
732+
a = T[:, 0]
733+
b = T[:, 1]
734+
735+
b = unitvec(b)
736+
# fmt: off
737+
R = np.array([
738+
[ b[1], b[0]],
739+
[-b[0], b[1]]
740+
])
741+
# fmt: on
742+
743+
if ishom2(T):
744+
return rt2tr(cast(SO2Array, R), T[:2, 2])
745+
else:
746+
return R
747+
748+
682749
@overload # pragma: no cover
683750
def tradjoint2(T: SO2Array) -> R1x1:
684751
...

spatialmath/base/transforms3d.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,13 +1533,17 @@ def trexp(S, theta=None, check=True):
15331533
raise ValueError(" First argument must be SO(3), 3-vector, SE(3) or 6-vector")
15341534

15351535

1536+
@overload # pragma: no cover
1537+
def trnorm(R: SO3Array) -> SO3Array:
1538+
...
1539+
1540+
15361541
def trnorm(T: SE3Array) -> SE3Array:
15371542
r"""
15381543
Normalize an SO(3) or SE(3) matrix
15391544
1540-
:param R: SE(3) or SO(3) matrix
1541-
:type R: ndarray(4,4) or ndarray(3,3)
1542-
:param T1: second SE(3) matrix
1545+
:param T: SE(3) or SO(3) matrix
1546+
:type T: ndarray(4,4) or ndarray(3,3)
15431547
:return: normalized SE(3) or SO(3) matrix
15441548
:rtype: ndarray(4,4) or ndarray(3,3)
15451549
:raises ValueError: bad arguments
@@ -1565,9 +1569,9 @@ def trnorm(T: SE3Array) -> SE3Array:
15651569
>>> T = troty(45, 'deg', t=[3, 4, 5])
15661570
>>> linalg.det(T[:3,:3]) - 1 # is a valid SO(3)
15671571
>>> T = T @ T @ T @ T @ T @ T @ T @ T @ T @ T @ T @ T @ T
1568-
>>> linalg.det(T[:3,:3]) - 1 # not quite a valid SO(3) anymore
1572+
>>> linalg.det(T[:3,:3]) - 1 # not quite a valid SE(3) anymore
15691573
>>> T = trnorm(T)
1570-
>>> linalg.det(T[:3,:3]) - 1 # once more a valid SO(3)
1574+
>>> linalg.det(T[:3,:3]) - 1 # once more a valid SE(3)
15711575
15721576
.. note::
15731577

spatialmath/base/vectors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,11 @@ def angdiff(a, b=None):
647647
b = getvector(b)
648648
a = a - b # cannot use -= here, numpy wont broadcast
649649

650-
return np.mod(a + math.pi, 2 * math.pi) - math.pi
650+
y = np.mod(a + math.pi, 2 * math.pi) - math.pi
651+
if isinstance(y, np.ndarray) and len(y) == 1:
652+
return float(y)
653+
else:
654+
return y
651655

652656

653657
def angle_std(theta: ArrayLike) -> float:

tests/base/test_transforms2d.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ def test_trexp2(self):
9494
T = transl2(1, 2) @ trot2(0.5)
9595
nt.assert_array_almost_equal(trexp2(logm(T)), T)
9696

97+
def test_trnorm2(self):
98+
R = rot2(0.4)
99+
R = np.round(R, 3) # approx SO(2)
100+
R = trnorm2(R)
101+
self.assertTrue(isrot2(R, check=True))
102+
103+
R = rot2(0.4)
104+
R = np.round(R, 3) # approx SO(2)
105+
T = rt2tr(R, [3, 4])
106+
107+
T = trnorm2(T)
108+
self.assertTrue(ishom2(T, check=True))
109+
nt.assert_almost_equal(T[:2, 2], [3, 4])
110+
97111
def test_transl2(self):
98112
nt.assert_array_almost_equal(
99113
transl2(1, 2), np.array([[1, 0, 1], [0, 1, 2], [0, 0, 1]])

tests/base/test_transforms3d.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,20 @@ def test_tr2rpy(self):
430430
a = rpy2tr(ang, order=seq)
431431
nt.assert_array_almost_equal(rpy2tr(tr2rpy(a, order=seq), order=seq), a)
432432

433+
def test_trnorm(self):
434+
R = rpy2r(0.2, 0.3, 0.4)
435+
R = np.round(R, 3) # approx SO(3)
436+
R = trnorm(R)
437+
self.assertTrue(isrot(R, check=True))
438+
439+
R = rpy2r(0.2, 0.3, 0.4)
440+
R = np.round(R, 3) # approx SO(3)
441+
T = rt2tr(R, [3, 4, 5])
442+
443+
T = trnorm(T)
444+
self.assertTrue(ishom(T, check=True))
445+
nt.assert_almost_equal(T[:3, 3], [3, 4, 5])
446+
433447
def test_tr2eul(self):
434448
eul = np.r_[0.1, 0.2, 0.3]
435449
R = eul2r(eul)

tests/base/test_vectors.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,16 +226,23 @@ def test_iszero(self):
226226

227227
def test_angdiff(self):
228228
self.assertEqual(angdiff(0, 0), 0)
229+
self.assertIsInstance(angdiff(0, 0), float)
229230
self.assertEqual(angdiff(pi, 0), -pi)
230231
self.assertEqual(angdiff(-pi, pi), 0)
231232

232-
nt.assert_array_almost_equal(angdiff([0, -pi, pi], 0), [0, -pi, -pi])
233+
x = angdiff([0, -pi, pi], 0)
234+
nt.assert_array_almost_equal(x, [0, -pi, -pi])
235+
self.assertIsInstance(x, np.ndarray)
233236
nt.assert_array_almost_equal(angdiff([0, -pi, pi], pi), [-pi, 0, 0])
234237

235-
nt.assert_array_almost_equal(angdiff(0, [0, -pi, pi]), [0, -pi, -pi])
238+
x = angdiff(0, [0, -pi, pi])
239+
nt.assert_array_almost_equal(x, [0, -pi, -pi])
240+
self.assertIsInstance(x, np.ndarray)
236241
nt.assert_array_almost_equal(angdiff(pi, [0, -pi, pi]), [-pi, 0, 0])
237242

238-
nt.assert_array_almost_equal(angdiff([1, 2, 3], [1, 2, 3]), [0, 0, 0])
243+
x = angdiff([1, 2, 3], [1, 2, 3])
244+
nt.assert_array_almost_equal(x, [0, 0, 0])
245+
self.assertIsInstance(x, np.ndarray)
239246

240247
def test_wrap(self):
241248
self.assertAlmostEqual(wrap_0_2pi(0), 0)

0 commit comments

Comments
 (0)