Skip to content

Commit 3dec5a8

Browse files
authored
Update more tolerance usages, add a few docstring and unit test (#106)
1 parent fd9e6a0 commit 3dec5a8

File tree

9 files changed

+191
-29
lines changed

9 files changed

+191
-29
lines changed

spatialmath/base/quaternions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def r2q(
761761

762762

763763
def qslerp(
764-
q0: ArrayLike4, q1: ArrayLike4, s: float, shortest: Optional[bool] = False
764+
q0: ArrayLike4, q1: ArrayLike4, s: float, shortest: Optional[bool] = False, tol: float = 10
765765
) -> UnitQuaternionArray:
766766
"""
767767
Quaternion conjugate
@@ -774,6 +774,8 @@ def qslerp(
774774
:type s: float
775775
:arg shortest: choose shortest distance [default False]
776776
:type shortest: bool
777+
:param tol: Tolerance when checking for identical quaternions, in multiples of eps, defaults to 10
778+
:type tol: float, optional
777779
:return: interpolated unit-quaternion
778780
:rtype: ndarray(4)
779781
:raises ValueError: s is outside interval [0, 1]
@@ -822,7 +824,7 @@ def qslerp(
822824

823825
dotprod = np.clip(dotprod, -1, 1) # Clip within domain of acos()
824826
theta = math.acos(dotprod) # theta is the angle between rotation vectors
825-
if abs(theta) > 10 * _eps:
827+
if abs(theta) > tol * _eps:
826828
s0 = math.sin((1 - s) * theta)
827829
s1 = math.sin(s * theta)
828830
return ((q0 * s0) + (q1 * s1)) / math.sin(theta)

spatialmath/base/transforms2d.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,15 @@ def tr2pos2(T):
266266
:return: translation elements of SE(2) matrix
267267
:rtype: ndarray(2)
268268
269-
- ``t = transl2(T)`` is the translational part of the SE(3) matrix ``T`` as a
269+
- ``t = tr2pos2(T)`` is the translational part of the SE(3) matrix ``T`` as a
270270
2-element NumPy array.
271271
272272
.. runblock:: pycon
273273
274274
>>> from spatialmath.base import *
275275
>>> import numpy as np
276276
>>> T = np.array([[1, 0, 3], [0, 1, 4], [0, 0, 1]])
277-
>>> transl2(T)
277+
>>> tr2pos2(T)
278278
279279
:seealso: :func:`pos2tr2` :func:`transl2`
280280
"""
@@ -292,19 +292,19 @@ def pos2tr2(x, y=None):
292292
:return: SE(2) matrix
293293
:rtype: ndarray(3,3)
294294
295-
- ``T = transl2([X, Y])`` is an SE(2) homogeneous transform (3x3)
295+
- ``T = pos2tr2([X, Y])`` is an SE(2) homogeneous transform (3x3)
296296
representing a pure translation.
297-
- ``T = transl2( V )`` as above but the translation is given by a 2-element
297+
- ``T = pos2tr2( V )`` as above but the translation is given by a 2-element
298298
list, dict, or a numpy array, row or column vector.
299299
300300
301301
.. runblock:: pycon
302302
303303
>>> from spatialmath.base import *
304304
>>> import numpy as np
305-
>>> transl2(3, 4)
306-
>>> transl2([3, 4])
307-
>>> transl2(np.array([3, 4]))
305+
>>> pos2tr2(3, 4)
306+
>>> pos2tr2([3, 4])
307+
>>> pos2tr2(np.array([3, 4]))
308308
309309
:seealso: :func:`tr2pos2` :func:`transl2`
310310
"""
@@ -1016,8 +1016,22 @@ def trprint2(
10161016
return s
10171017

10181018

1019-
def _vec2s(fmt: str, v: ArrayLikePure):
1020-
v = [x if np.abs(x) > 100 * _eps else 0.0 for x in v]
1019+
def _vec2s(fmt: str, v: ArrayLikePure, tol: float = 100) -> str:
1020+
"""
1021+
Return a string representation for vector using the provided fmt.
1022+
1023+
:param fmt: format string for each value in v
1024+
:type fmt: str
1025+
:param tol: Tolerance when checking for near-zero values, in multiples of eps, defaults to 100
1026+
:type tol: float, optional
1027+
:return: string representation for the vector
1028+
:rtype: str
1029+
1030+
Return a string representation for vector using the provided fmt, where
1031+
near-zero values are rounded to 0.
1032+
"""
1033+
1034+
v = [x if np.abs(x) > tol * _eps else 0.0 for x in v]
10211035
return ", ".join([fmt.format(x) for x in v])
10221036

10231037

spatialmath/base/transforms3d.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ def eul2tr(
713713
# ---------------------------------------------------------------------------------------#
714714

715715

716-
def angvec2r(theta: float, v: ArrayLike3, unit="rad") -> SO3Array:
716+
def angvec2r(theta: float, v: ArrayLike3, unit="rad", tol: float = 10) -> SO3Array:
717717
"""
718718
Create an SO(3) rotation matrix from rotation angle and axis
719719
@@ -723,6 +723,8 @@ def angvec2r(theta: float, v: ArrayLike3, unit="rad") -> SO3Array:
723723
:type unit: str
724724
:param v: 3D rotation axis
725725
:type v: array_like(3)
726+
:param tol: Tolerance in units of eps for zero-rotation case, defaults to 10
727+
:type: float
726728
:return: SO(3) rotation matrix
727729
:rtype: ndarray(3,3)
728730
:raises ValueError: bad arguments
@@ -748,7 +750,7 @@ def angvec2r(theta: float, v: ArrayLike3, unit="rad") -> SO3Array:
748750
if not isscalar(theta) or not isvector(v, 3):
749751
raise ValueError("Arguments must be angle and vector")
750752

751-
if np.linalg.norm(v) < 10 * _eps:
753+
if np.linalg.norm(v) < tol * _eps:
752754
return np.eye(3)
753755

754756
θ = getunit(theta, unit)
@@ -1044,6 +1046,7 @@ def tr2eul(
10441046
unit: str = "rad",
10451047
flip: bool = False,
10461048
check: bool = False,
1049+
tol: float = 10,
10471050
) -> R3:
10481051
r"""
10491052
Convert SO(3) or SE(3) to ZYX Euler angles
@@ -1056,6 +1059,8 @@ def tr2eul(
10561059
:type flip: bool
10571060
:param check: check that rotation matrix is valid
10581061
:type check: bool
1062+
:param tol: Tolerance in units of eps for near-zero checks, defaults to 10
1063+
:type: float
10591064
:return: ZYZ Euler angles
10601065
:rtype: ndarray(3)
10611066
@@ -1090,11 +1095,11 @@ def tr2eul(
10901095
R = t2r(T)
10911096
else:
10921097
R = T
1093-
if not isrot(R, check=check):
1098+
if not isrot(R, check=check, tol=tol):
10941099
raise ValueError("argument is not SO(3)")
10951100

10961101
eul = np.zeros((3,))
1097-
if abs(R[0, 2]) < 10 * _eps and abs(R[1, 2]) < 10 * _eps:
1102+
if abs(R[0, 2]) < tol * _eps and abs(R[1, 2]) < tol * _eps:
10981103
eul[0] = 0
10991104
sp = 0
11001105
cp = 1
@@ -1124,6 +1129,7 @@ def tr2rpy(
11241129
unit: str = "rad",
11251130
order: str = "zyx",
11261131
check: bool = False,
1132+
tol: float = 10,
11271133
) -> R3:
11281134
r"""
11291135
Convert SO(3) or SE(3) to roll-pitch-yaw angles
@@ -1136,6 +1142,8 @@ def tr2rpy(
11361142
:type order: str
11371143
:param check: check that rotation matrix is valid
11381144
:type check: bool
1145+
:param tol: Tolerance in units of eps, defaults to 10
1146+
:type: float
11391147
:return: Roll-pitch-yaw angles
11401148
:rtype: ndarray(3)
11411149
:raises ValueError: bad arguments
@@ -1176,13 +1184,13 @@ def tr2rpy(
11761184
R = t2r(T)
11771185
else:
11781186
R = T
1179-
if not isrot(R, check=check):
1187+
if not isrot(R, check=check, tol=tol):
11801188
raise ValueError("not a valid SO(3) matrix")
11811189

11821190
rpy = np.zeros((3,))
11831191
if order in ("xyz", "arm"):
11841192
# XYZ order
1185-
if abs(abs(R[0, 2]) - 1) < 10 * _eps: # when |R13| == 1
1193+
if abs(abs(R[0, 2]) - 1) < tol * _eps: # when |R13| == 1
11861194
# singularity
11871195
rpy[0] = 0 # roll is zero
11881196
if R[0, 2] > 0:
@@ -1206,7 +1214,7 @@ def tr2rpy(
12061214

12071215
elif order in ("zyx", "vehicle"):
12081216
# old ZYX order (as per Paul book)
1209-
if abs(abs(R[2, 0]) - 1) < 10 * _eps: # when |R31| == 1
1217+
if abs(abs(R[2, 0]) - 1) < tol * _eps: # when |R31| == 1
12101218
# singularity
12111219
rpy[0] = 0 # roll is zero
12121220
if R[2, 0] < 0:
@@ -1229,7 +1237,7 @@ def tr2rpy(
12291237
rpy[1] = -math.atan(R[2, 0] * math.cos(rpy[0]) / R[2, 2])
12301238

12311239
elif order in ("yxz", "camera"):
1232-
if abs(abs(R[1, 2]) - 1) < 10 * _eps: # when |R23| == 1
1240+
if abs(abs(R[1, 2]) - 1) < tol * _eps: # when |R23| == 1
12331241
# singularity
12341242
rpy[0] = 0
12351243
if R[1, 2] < 0:

spatialmath/base/vectors.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def unittwist(S: ArrayLike6, tol: float = 10) -> Union[R6, None]:
392392
return S / th
393393

394394

395-
def unittwist_norm(S: Union[R6, ArrayLike6], tol: float = 10) -> Tuple[R6, float]:
395+
def unittwist_norm(S: Union[R6, ArrayLike6], tol: float = 10) -> Tuple[Union[R6, None], Union[float, None]]:
396396
"""
397397
Convert twist to unit twist and norm
398398
@@ -424,7 +424,7 @@ def unittwist_norm(S: Union[R6, ArrayLike6], tol: float = 10) -> Tuple[R6, float
424424
S = getvector(S, 6)
425425

426426
if iszerovec(S, tol=tol):
427-
raise ValueError("zero norm")
427+
return (None, None) # according to "note" in docstring.
428428

429429
v = S[0:3]
430430
w = S[3:6]
@@ -437,7 +437,7 @@ def unittwist_norm(S: Union[R6, ArrayLike6], tol: float = 10) -> Tuple[R6, float
437437
return (S / th, th)
438438

439439

440-
def unittwist2(S: ArrayLike3, tol: float = 10) -> R3:
440+
def unittwist2(S: ArrayLike3, tol: float = 10) -> Union[R3, None]:
441441
"""
442442
Convert twist to unit twist
443443
@@ -459,9 +459,14 @@ def unittwist2(S: ArrayLike3, tol: float = 10) -> R3:
459459
>>> unittwist2([2, 4, 2)
460460
>>> unittwist2([2, 0, 0])
461461
462+
.. note:: Returns None if the twist has zero magnitude
462463
"""
463464

464465
S = getvector(S, 3)
466+
467+
if iszerovec(S, tol=tol):
468+
return None
469+
465470
v = S[0:2]
466471
w = S[2]
467472

@@ -473,7 +478,7 @@ def unittwist2(S: ArrayLike3, tol: float = 10) -> R3:
473478
return S / th
474479

475480

476-
def unittwist2_norm(S: ArrayLike3, tol: float = 10) -> Tuple[R3, float]:
481+
def unittwist2_norm(S: ArrayLike3, tol: float = 10) -> Tuple[Union[R3, None], Union[float, None]]:
477482
"""
478483
Convert twist to unit twist
479484
@@ -495,9 +500,14 @@ def unittwist2_norm(S: ArrayLike3, tol: float = 10) -> Tuple[R3, float]:
495500
>>> unittwist2([2, 4, 2)
496501
>>> unittwist2([2, 0, 0])
497502
503+
.. note:: Returns (None, None) if the twist has zero magnitude
498504
"""
499505

500506
S = getvector(S, 3)
507+
508+
if iszerovec(S, tol=tol):
509+
return (None, None)
510+
501511
v = S[0:2]
502512
w = S[2]
503513

@@ -728,7 +738,7 @@ def angle_wrap(theta: ArrayLike, mode: str = "-pi:pi") -> Union[float, NDArray]:
728738
return wrap_mpi_pi(theta)
729739
elif mode == "0:pi":
730740
return wrap_0_pi(theta)
731-
elif mode == "0:pi":
741+
elif mode == "-pi/2:pi/2":
732742
return wrap_mpi2_pi2(theta)
733743
else:
734744
raise ValueError("bad method specified")

spatialmath/geom3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def isintersecting(
712712
713713
:seealso: :meth:`__xor__` :meth:`intersects` :meth:`isparallel`
714714
"""
715-
return not l1.isparallel(l2, tol=tol) and bool(abs(l1 * l2) < 10 * _eps)
715+
return not l1.isparallel(l2, tol=tol) and bool(abs(l1 * l2) < tol * _eps)
716716

717717
def __eq__(l1, l2: Line3) -> bool: # type: ignore pylint: disable=no-self-argument
718718
"""

spatialmath/quaternion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,10 +396,12 @@ def log(self) -> Quaternion:
396396
v = math.acos(self.s / norm) * smb.unitvec(self.v)
397397
return Quaternion(s=s, v=v)
398398

399-
def exp(self) -> Quaternion:
399+
def exp(self, tol: float=100) -> Quaternion:
400400
r"""
401401
Exponential of quaternion
402402
403+
:param tol: Tolerance when checking for pure quaternion, in multiples of eps, defaults to 100
404+
:type tol: float, optional
403405
:rtype: Quaternion instance
404406
405407
``q.exp()`` is the exponential of the quaternion ``q``, ie.
@@ -433,7 +435,7 @@ def exp(self) -> Quaternion:
433435
norm_v = smb.norm(self.v)
434436
s = exp_s * math.cos(norm_v)
435437
v = exp_s * self.v / norm_v * math.sin(norm_v)
436-
if abs(self.s) < 100 * _eps:
438+
if abs(self.s) < tol * _eps:
437439
# result will be a unit quaternion
438440
return UnitQuaternion(s=s, v=v)
439441
else:

tests/base/test_transforms2d.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,30 @@ def test_transl2(self):
118118
transl2([1, 2]), np.array([[1, 0, 1], [0, 1, 2], [0, 0, 1]])
119119
)
120120

121+
def test_pos2tr2(self):
122+
nt.assert_array_almost_equal(
123+
pos2tr2(1, 2), np.array([[1, 0, 1], [0, 1, 2], [0, 0, 1]])
124+
)
125+
nt.assert_array_almost_equal(
126+
transl2([1, 2]), np.array([[1, 0, 1], [0, 1, 2], [0, 0, 1]])
127+
)
128+
nt.assert_array_almost_equal(
129+
tr2pos2(pos2tr2(1, 2)), np.array([1, 2])
130+
)
131+
132+
def test_tr2jac2(self):
133+
T = trot2(0.3, t=[4, 5])
134+
jac2 = tr2jac2(T)
135+
nt.assert_array_almost_equal(
136+
jac2[:2, :2], smb.t2r(T)
137+
)
138+
nt.assert_array_almost_equal(
139+
jac2[:3, 2], np.array([0, 0, 1])
140+
)
141+
nt.assert_array_almost_equal(
142+
jac2[2, :3], np.array([0, 0, 1])
143+
)
144+
121145
def test_xyt2tr(self):
122146
T = xyt2tr([1, 2, 0])
123147
nt.assert_array_almost_equal(T, transl2(1, 2))

0 commit comments

Comments
 (0)