Skip to content

Commit ca1a2bd

Browse files
committed
Raise error on invalid rotation matrix or other array passed to
UnitQuaternion constructor.
1 parent a07e74d commit ca1a2bd

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

spatialmath/quaternion.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,12 @@ def __init__(
989989
# a quaternion as a 1D array
990990
# an array of quaternions as an nx4 array
991991

992-
if smb.isrot(s, check=check):
993-
# UnitQuaternion(R) R is 3x3 rotation matrix
994-
self.data = [smb.r2q(s)]
992+
if s.shape == (3, 3):
993+
if smb.isrot(s, check=check):
994+
# UnitQuaternion(R) R is 3x3 rotation matrix
995+
self.data = [smb.r2q(s)]
996+
else:
997+
raise ValueError("invalid rotation matrix provided to UnitQuaternion constructor")
995998
elif s.shape == (4,):
996999
# passed a 4-vector
9971000
if norm:
@@ -1004,6 +1007,8 @@ def __init__(
10041007
else:
10051008
# self.data = [smb.qpositive(x) for x in s]
10061009
self.data = [x for x in s]
1010+
else:
1011+
raise ValueError("array could not be interpreted as UnitQuaternion")
10071012

10081013
elif isinstance(s, SO3):
10091014
# UnitQuaternion(x) x is SO3 or SE3 (since SE3 is subclass of SO3)

tests/test_quaternion.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,23 @@ def test_constructor(self):
151151
q = UnitQuaternion(rotx(0.3))
152152
qcompare(UnitQuaternion(q), q)
153153

154+
# fail when invalid arrays are provided
155+
# invalid rotation matrix
156+
R = 1.1 * np.eye(3)
157+
with self.assertRaises(ValueError):
158+
UnitQuaternion(R, check=True)
159+
160+
# no check, so try to interpret as a quaternion, but shape is wrong
161+
with self.assertRaises(ValueError):
162+
UnitQuaternion(R, check=False)
163+
164+
# wrong shape to be anything
165+
R = np.zeros((5, 5))
166+
with self.assertRaises(ValueError):
167+
UnitQuaternion(R, check=True)
168+
with self.assertRaises(ValueError):
169+
UnitQuaternion(R, check=False)
170+
154171
def test_concat(self):
155172
u = UnitQuaternion()
156173
uu = UnitQuaternion([u, u, u, u])

0 commit comments

Comments
 (0)