Skip to content

Commit 39d6127

Browse files
authored
Fix for issue #144. (#155)
1 parent 659ba24 commit 39d6127

File tree

6 files changed

+109
-40
lines changed

6 files changed

+109
-40
lines changed

spatialmath/base/argcheck.py

+34-16
Original file line numberDiff line numberDiff line change
@@ -522,16 +522,20 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool:
522522
return False
523523

524524

525-
def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
525+
def getunit(
526+
v: ArrayLike, unit: str = "rad", dim: Optional[int] = None, vector: bool = True
527+
) -> Union[float, NDArray]:
526528
"""
527529
Convert values according to angular units
528530
529531
:param v: the value in radians or degrees
530532
:type v: array_like(m)
531533
:param unit: the angular unit, "rad" or "deg"
532534
:type unit: str
533-
:param dim: expected dimension of input, defaults to None
535+
:param dim: expected dimension of input, defaults to don't check (None)
534536
:type dim: int, optional
537+
:param vector: return a scalar as a 1d vector, defaults to True
538+
:type vector: bool, optional
535539
:return: the converted value in radians
536540
:rtype: ndarray(m) or float
537541
:raises ValueError: argument is not a valid angular unit
@@ -543,30 +547,44 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
543547
>>> from spatialmath.base import getunit
544548
>>> import numpy as np
545549
>>> getunit(1.5, 'rad')
546-
>>> getunit(1.5, 'rad', dim=0)
547-
>>> # getunit([1.5], 'rad', dim=0) --> ValueError
548550
>>> getunit(90, 'deg')
551+
>>> getunit(90, 'deg', vector=False) # force a scalar output
552+
>>> getunit(1.5, 'rad', dim=0) # check argument is scalar
553+
>>> getunit(1.5, 'rad', dim=3) # check argument is a 3-vector
554+
>>> getunit([1.5], 'rad', dim=1) # check argument is a 1-vector
555+
>>> getunit([1.5], 'rad', dim=3) # check argument is a 3-vector
549556
>>> getunit([90, 180], 'deg')
550-
>>> getunit(np.r_[0.5, 1], 'rad')
551557
>>> getunit(np.r_[90, 180], 'deg')
552-
>>> getunit(np.r_[90, 180], 'deg', dim=2)
553-
>>> # getunit([90, 180], 'deg', dim=3) --> ValueError
558+
>>> getunit(np.r_[90, 180], 'deg', dim=2) # check argument is a 2-vector
559+
>>> getunit([90, 180], 'deg', dim=3) # check argument is a 3-vector
554560
555561
:note:
556562
- the input value is processed by :func:`getvector` and the argument ``dim`` can
557-
be used to check that ``v`` is the desired length.
558-
- the output is always an ndarray except if the input is a scalar and ``dim=0``.
563+
be used to check that ``v`` is the desired length. Note that 0 means a scalar,
564+
whereas 1 means a 1-element array.
565+
- the output is always an ndarray except if the input is a scalar and ``vector=False``.
559566
560567
:seealso: :func:`getvector`
561568
"""
562-
if not isinstance(v, Iterable) and dim == 0:
563-
# scalar in, scalar out
564-
if unit == "rad":
565-
return v
566-
elif unit == "deg":
567-
return np.deg2rad(v)
569+
if not isinstance(v, Iterable):
570+
# scalar input
571+
if dim is not None and dim != 0:
572+
raise ValueError("for dim==0 input must be a scalar")
573+
if vector:
574+
# scalar in, vector out
575+
if unit == "deg":
576+
v = np.deg2rad(v)
577+
elif unit != "rad":
578+
raise ValueError("invalid angular units")
579+
return np.array([v])
568580
else:
569-
raise ValueError("invalid angular units")
581+
# scalar in, scalar out
582+
if unit == "rad":
583+
return v
584+
elif unit == "deg":
585+
return np.deg2rad(v)
586+
else:
587+
raise ValueError("invalid angular units")
570588

571589
else:
572590
# scalar or iterable in, ndarray out

spatialmath/base/transforms2d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array:
6363
>>> rot2(0.3)
6464
>>> rot2(45, 'deg')
6565
"""
66-
theta = smb.getunit(theta, unit, dim=0)
66+
theta = smb.getunit(theta, unit, vector=False)
6767
ct = smb.sym.cos(theta)
6868
st = smb.sym.sin(theta)
6969
# fmt: off

spatialmath/base/transforms3d.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def rotx(theta: float, unit: str = "rad") -> SO3Array:
7979
:SymPy: supported
8080
"""
8181

82-
theta = getunit(theta, unit, dim=0)
82+
theta = getunit(theta, unit, vector=False)
8383
ct = sym.cos(theta)
8484
st = sym.sin(theta)
8585
# fmt: off
@@ -118,7 +118,7 @@ def roty(theta: float, unit: str = "rad") -> SO3Array:
118118
:SymPy: supported
119119
"""
120120

121-
theta = getunit(theta, unit, dim=0)
121+
theta = getunit(theta, unit, vector=False)
122122
ct = sym.cos(theta)
123123
st = sym.sin(theta)
124124
# fmt: off
@@ -152,7 +152,7 @@ def rotz(theta: float, unit: str = "rad") -> SO3Array:
152152
:seealso: :func:`~trotz`
153153
:SymPy: supported
154154
"""
155-
theta = getunit(theta, unit, dim=0)
155+
theta = getunit(theta, unit, vector=False)
156156
ct = sym.cos(theta)
157157
st = sym.sin(theta)
158158
# fmt: off
@@ -2709,7 +2709,7 @@ def tr2adjoint(T):
27092709
27102710
:Reference:
27112711
- Robotics, Vision & Control for Python, Section 3, P. Corke, Springer 2023.
2712-
- `Lie groups for 2D and 3D Transformations <http://ethaneade.com/lie.pdf>_
2712+
- `Lie groups for 2D and 3D Transformations <http://ethaneade.com/lie.pdf>`_
27132713
27142714
:SymPy: supported
27152715
"""
@@ -3002,29 +3002,36 @@ def trplot(
30023002
- ``width`` of line
30033003
- ``length`` of line
30043004
- ``style`` which is one of:
3005+
30053006
- ``'arrow'`` [default], draw line with arrow head in ``color``
30063007
- ``'line'``, draw line with no arrow head in ``color``
30073008
- ``'rgb'``, frame axes are lines with no arrow head and red for X, green
3008-
for Y, blue for Z; no origin dot
3009+
for Y, blue for Z; no origin dot
30093010
- ``'rviz'``, frame axes are thick lines with no arrow head and red for X,
3010-
green for Y, blue for Z; no origin dot
3011+
green for Y, blue for Z; no origin dot
3012+
30113013
- coordinate axis labels depend on:
3014+
30123015
- ``axislabel`` if True [default] label the axis, default labels are X, Y, Z
30133016
- ``labels`` 3-list of alternative axis labels
30143017
- ``textcolor`` which defaults to ``color``
30153018
- ``axissubscript`` if True [default] add the frame label ``frame`` as a subscript
3016-
for each axis label
3019+
for each axis label
3020+
30173021
- coordinate frame label depends on:
3022+
30183023
- `frame` the label placed inside {} near the origin of the frame
3024+
30193025
- a dot at the origin
3026+
30203027
- ``originsize`` size of the dot, if zero no dot
30213028
- ``origincolor`` color of the dot, defaults to ``color``
30223029
30233030
Examples::
30243031
3025-
trplot(T, frame='A')
3026-
trplot(T, frame='A', color='green')
3027-
trplot(T1, 'labels', 'UVW');
3032+
trplot(T, frame='A')
3033+
trplot(T, frame='A', color='green')
3034+
trplot(T1, 'labels', 'UVW');
30283035
30293036
.. plot::
30303037
@@ -3383,12 +3390,12 @@ def tranimate(T: Union[SO3Array, SE3Array], **kwargs) -> str:
33833390
:param **kwargs: arguments passed to ``trplot``
33843391
33853392
- ``tranimate(T)`` where ``T`` is an SO(3) or SE(3) matrix, animates a 3D
3386-
coordinate frame moving from the world frame to the frame ``T`` in
3387-
``nsteps``.
3393+
coordinate frame moving from the world frame to the frame ``T`` in
3394+
``nsteps``.
33883395
33893396
- ``tranimate(I)`` where ``I`` is an iterable or generator, animates a 3D
3390-
coordinate frame representing the pose of each element in the sequence of
3391-
SO(3) or SE(3) matrices.
3397+
coordinate frame representing the pose of each element in the sequence of
3398+
SO(3) or SE(3) matrices.
33923399
33933400
Examples:
33943401

spatialmath/base/vectors.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -530,14 +530,15 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]:
530530
:param theta: input angle
531531
:type theta: scalar or ndarray
532532
:return: angle wrapped into range :math:`[0, \pi)`
533+
:rtype: scalar or ndarray
533534
534535
This is used to fold angles of colatitude. If zero is the angle of the
535536
north pole, colatitude increases to :math:`\pi` at the south pole then
536537
decreases to :math:`0` as we head back to the north pole.
537538
538539
:seealso: :func:`wrap_mpi2_pi2` :func:`wrap_0_2pi` :func:`wrap_mpi_pi` :func:`angle_wrap`
539540
"""
540-
theta = np.abs(theta)
541+
theta = np.abs(getvector(theta))
541542
n = theta / np.pi
542543
if isinstance(n, np.ndarray):
543544
n = n.astype(int)
@@ -546,7 +547,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]:
546547

547548
y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, (n + 1) * np.pi - theta)
548549
if isinstance(y, np.ndarray) and y.size == 1:
549-
return float(y)
550+
return float(y[0])
550551
else:
551552
return y
552553

@@ -558,6 +559,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]:
558559
:param theta: input angle
559560
:type theta: scalar or ndarray
560561
:return: angle wrapped into range :math:`[-\pi/2, \pi/2]`
562+
:rtype: scalar or ndarray
561563
562564
This is used to fold angles of latitude.
563565
@@ -573,7 +575,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]:
573575

574576
y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, n * np.pi - theta)
575577
if isinstance(y, np.ndarray) and len(y) == 1:
576-
return float(y)
578+
return float(y[0])
577579
else:
578580
return y
579581

@@ -585,13 +587,14 @@ def wrap_0_2pi(theta: ArrayLike) -> Union[float, NDArray]:
585587
:param theta: input angle
586588
:type theta: scalar or ndarray
587589
:return: angle wrapped into range :math:`[0, 2\pi)`
590+
:rtype: scalar or ndarray
588591
589592
:seealso: :func:`wrap_mpi_pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap`
590593
"""
591594
theta = getvector(theta)
592595
y = theta - 2.0 * math.pi * np.floor(theta / 2.0 / np.pi)
593596
if isinstance(y, np.ndarray) and len(y) == 1:
594-
return float(y)
597+
return float(y[0])
595598
else:
596599
return y
597600

@@ -603,13 +606,14 @@ def wrap_mpi_pi(theta: ArrayLike) -> Union[float, NDArray]:
603606
:param theta: input angle
604607
:type theta: scalar or ndarray
605608
:return: angle wrapped into range :math:`[-\pi, \pi)`
609+
:rtype: scalar or ndarray
606610
607611
:seealso: :func:`wrap_0_2pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap`
608612
"""
609613
theta = getvector(theta)
610614
y = np.mod(theta + math.pi, 2 * math.pi) - np.pi
611615
if isinstance(y, np.ndarray) and len(y) == 1:
612-
return float(y)
616+
return float(y[0])
613617
else:
614618
return y
615619

@@ -643,6 +647,7 @@ def angdiff(a, b=None):
643647
- ``angdiff(a, b)`` is the difference ``a - b`` wrapped to the range
644648
:math:`[-\pi, \pi)`. This is the operator :math:`a \circleddash b` used
645649
in the RVC book
650+
646651
- If ``a`` and ``b`` are both scalars, the result is scalar
647652
- If ``a`` is array_like, the result is a NumPy array ``a[i]-b``
648653
- If ``a`` is array_like, the result is a NumPy array ``a-b[i]``
@@ -651,6 +656,7 @@ def angdiff(a, b=None):
651656
652657
- ``angdiff(a)`` is the angle or vector of angles ``a`` wrapped to the range
653658
:math:`[-\pi, \pi)`.
659+
654660
- If ``a`` is a scalar, the result is scalar
655661
- If ``a`` is array_like, the result is a NumPy array
656662
@@ -671,7 +677,7 @@ def angdiff(a, b=None):
671677

672678
y = np.mod(a + math.pi, 2 * math.pi) - math.pi
673679
if isinstance(y, np.ndarray) and len(y) == 1:
674-
return float(y)
680+
return float(y[0])
675681
else:
676682
return y
677683

spatialmath/quaternion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,7 @@ def AngVec(
14431443
:seealso: :meth:`UnitQuaternion.angvec` :meth:`UnitQuaternion.exp` :func:`~spatialmath.base.transforms3d.angvec2r`
14441444
"""
14451445
v = smb.getvector(v, 3)
1446-
theta = smb.getunit(theta, unit, dim=0)
1446+
theta = smb.getunit(theta, unit, vector=False)
14471447
return cls(
14481448
s=math.cos(theta / 2), v=math.sin(theta / 2) * v, norm=False, check=False
14491449
)

tests/base/test_argcheck.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,49 @@ def test_verifymatrix(self):
122122
verifymatrix(a, (3, 4))
123123

124124
def test_unit(self):
125-
self.assertIsInstance(getunit(1), np.ndarray)
125+
# scalar -> vector
126+
self.assertEqual(getunit(1), np.array([1]))
127+
self.assertEqual(getunit(1, dim=0), np.array([1]))
128+
with self.assertRaises(ValueError):
129+
self.assertEqual(getunit(1, dim=1), np.array([1]))
130+
131+
self.assertEqual(getunit(1, unit="deg"), np.array([1 * math.pi / 180.0]))
132+
self.assertEqual(getunit(1, dim=0, unit="deg"), np.array([1 * math.pi / 180.0]))
133+
with self.assertRaises(ValueError):
134+
self.assertEqual(
135+
getunit(1, dim=1, unit="deg"), np.array([1 * math.pi / 180.0])
136+
)
137+
138+
# scalar -> scalar
139+
self.assertEqual(getunit(1, vector=False), 1)
140+
self.assertEqual(getunit(1, dim=0, vector=False), 1)
141+
with self.assertRaises(ValueError):
142+
self.assertEqual(getunit(1, dim=1, vector=False), 1)
143+
144+
self.assertIsInstance(getunit(1.0, vector=False), float)
145+
self.assertIsInstance(getunit(1, vector=False), int)
146+
147+
self.assertEqual(getunit(1, vector=False, unit="deg"), 1 * math.pi / 180.0)
148+
self.assertEqual(
149+
getunit(1, dim=0, vector=False, unit="deg"), 1 * math.pi / 180.0
150+
)
151+
with self.assertRaises(ValueError):
152+
self.assertEqual(
153+
getunit(1, dim=1, vector=False, unit="deg"), 1 * math.pi / 180.0
154+
)
155+
156+
self.assertIsInstance(getunit(1.0, vector=False, unit="deg"), float)
157+
self.assertIsInstance(getunit(1, vector=False, unit="deg"), float)
158+
159+
# vector -> vector
160+
self.assertEqual(getunit([1]), np.array([1]))
161+
self.assertEqual(getunit([1], dim=1), np.array([1]))
162+
with self.assertRaises(ValueError):
163+
getunit([1], dim=0)
164+
126165
self.assertIsInstance(getunit([1, 2]), np.ndarray)
127166
self.assertIsInstance(getunit((1, 2)), np.ndarray)
128167
self.assertIsInstance(getunit(np.r_[1, 2]), np.ndarray)
129-
self.assertIsInstance(getunit(1.0, dim=0), float)
130168

131169
nt.assert_equal(getunit(5, "rad"), 5)
132170
nt.assert_equal(getunit(5, "deg"), 5 * math.pi / 180.0)

0 commit comments

Comments
 (0)