Skip to content

Commit da8c794

Browse files
committed
ENH: Add displacement field interp
1 parent 6d022fd commit da8c794

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

mne/tests/test_transforms.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
_quat_to_affine,
5353
_compute_r2,
5454
_validate_pipeline,
55+
_MatchedDisplacementFieldInterpolator,
5556
)
5657

5758
data_path = testing.data_path(download=False)
@@ -630,3 +631,19 @@ def test_volume_registration():
630631
],
631632
atol=0.001,
632633
)
634+
635+
636+
def test_displacement_field():
637+
"""Test that our matched point deformation works."""
638+
to = np.array([[5, 4, 1], [6, 1, 0], [4, -1, 1], [3, 3, 0]], float)
639+
fro = np.array([[0, 2, 2], [2, 2, 1], [2, 0, 2], [0, 0, 1]], float)
640+
interp = _MatchedDisplacementFieldInterpolator(fro, to)
641+
fro_t = interp(fro)
642+
assert_allclose(to, fro_t, atol=1e-12)
643+
# check midpoints (should all be decent)
644+
for a in range(len(to)):
645+
for b in range(a + 1, len(to)):
646+
to_ = np.mean(to[[a, b]], axis=0)
647+
fro_ = np.mean(fro[[a, b]], axis=0)
648+
fro_t = interp(fro_)
649+
assert_allclose(to_, fro_t, atol=1e-12)

mne/transforms.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,3 +2073,47 @@ def apply_volume_registration_points(
20732073
info2.set_montage(montage2) # converts to head coordinates
20742074

20752075
return info2, trans2
2076+
2077+
2078+
class _MatchedDisplacementFieldInterpolator:
2079+
"""Interpolate from matched points using a displacement field in ND.
2080+
2081+
For a demo, see
2082+
https://gist.github.com/larsoner/fbe32d57996848395854d5e59dff1e10
2083+
and related tests.
2084+
"""
2085+
2086+
def __init__(self, fro, to):
2087+
from scipy.interpolate import LinearNDInterpolator
2088+
2089+
fro = np.array(fro, float)
2090+
to = np.array(to, float)
2091+
assert fro.shape == to.shape
2092+
assert fro.ndim == 2
2093+
# this restriction is only necessary because it's what
2094+
# _fit_matched_points requires
2095+
assert fro.shape[1] == 3
2096+
2097+
# Prealign using affine + uniform scaling
2098+
trans, scale = _fit_matched_points(fro, to, scale=True)
2099+
trans = _quat_to_affine(trans)
2100+
trans[:3, :3] *= scale
2101+
self._affine = trans
2102+
fro = apply_trans(trans, fro)
2103+
2104+
# Add points at extrema
2105+
delta = (to.max(axis=0) - to.min(axis=0)) / 2.0
2106+
extrema = np.array([fro.min(axis=0) - delta, fro.max(axis=0) + delta])
2107+
self._extrema = np.array(np.meshgrid(*extrema.T)).T.reshape(-1, fro.shape[-1])
2108+
fro_concat = np.concatenate((fro, self._extrema))
2109+
to_concat = np.concatenate((to, self._extrema))
2110+
2111+
# Compute the interpolator (which internally uses Delaunay)
2112+
self._interp = LinearNDInterpolator(fro_concat, to_concat)
2113+
2114+
def __call__(self, x):
2115+
assert x.ndim in (1, 2) and x.shape[-1] == 3
2116+
singleton = x.ndim == 1
2117+
out = self._interp(apply_trans(self._affine, x))
2118+
out = out[0] if singleton else out
2119+
return out

0 commit comments

Comments
 (0)