Skip to content

Commit c89285a

Browse files
Merge pull request #83 from bdaiinstitute/add-copy-when-passed-as-input
[SW-235] Added a copy_from method to copy inputs instead of just passing by re…
2 parents 3a93de3 + 218201e commit c89285a

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

spatialmath/pose3d.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,6 +1934,27 @@ def Rt(
19341934
t = np.zeros((3,))
19351935
return cls(smb.rt2tr(R, t, check=check), check=check)
19361936

1937+
@classmethod
1938+
def CopyFrom(
1939+
cls,
1940+
T: SE3Array,
1941+
check: bool = True
1942+
) -> SE3:
1943+
"""
1944+
Create an SE(3) from a 4x4 numpy array that is passed by value.
1945+
1946+
:param T: homogeneous transformation
1947+
:type T: ndarray(4, 4)
1948+
:param check: check rotation validity, defaults to True
1949+
:type check: bool, optional
1950+
:raises ValueError: bad rotation matrix, bad transformation matrix
1951+
:return: SE(3) matrix representing that transformation
1952+
:rtype: SE3 instance
1953+
"""
1954+
if T is None:
1955+
raise ValueError("Transformation matrix must not be None")
1956+
return cls(np.copy(T), check=check)
1957+
19371958
def angdist(self, other: SE3, metric: int = 6) -> float:
19381959
r"""
19391960
Angular distance metric between poses

tests/test_pose3d.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,21 @@ def test_properties(self):
888888
nt.assert_equal(R.N, 3)
889889
nt.assert_equal(R.shape, (4, 4))
890890

891+
# Testing the CopyFrom function
892+
mutable_array = np.eye(4)
893+
pass_by_ref = SE3(mutable_array)
894+
pass_by_val = SE3.CopyFrom(mutable_array)
895+
mutable_array[0, 3] = 5.0
896+
nt.assert_allclose(pass_by_val.data[0], np.eye(4))
897+
nt.assert_allclose(pass_by_ref.data[0], mutable_array)
898+
nt.assert_raises(
899+
AssertionError,
900+
nt.assert_allclose,
901+
pass_by_val.data[0],
902+
pass_by_ref.data[0]
903+
)
904+
905+
891906
def test_arith(self):
892907
T = SE3(1, 2, 3)
893908

0 commit comments

Comments
 (0)