Skip to content

Commit d9dcf11

Browse files
BUG: Fix bug with simulating head pos and BEM (#13276)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 40b9232 commit d9dcf11

File tree

4 files changed

+46
-15
lines changed

4 files changed

+46
-15
lines changed

doc/changes/devel/13276.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug with :func:`mne.simulation.simulate_raw` and :class:`mne.simulation.SourceSimulator` where using different head positions with ``head_pos`` and a BEM would raise an error, by `Eric Larson`_.

mne/chpi.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,9 @@ def read_head_pos(fname):
9999
100100
Returns
101101
-------
102-
pos : array, shape (N, 10)
102+
quats : array, shape (n_pos, 10)
103103
The position and quaternion parameters from cHPI fitting.
104+
See :func:`mne.chpi.compute_head_pos` for details on the columns.
104105
105106
See Also
106107
--------
@@ -126,8 +127,9 @@ def write_head_pos(fname, pos):
126127
----------
127128
fname : path-like
128129
The filename to write.
129-
pos : array, shape (N, 10)
130+
pos : array, shape (n_pos, 10)
130131
The position and quaternion parameters from cHPI fitting.
132+
See :func:`mne.chpi.compute_head_pos` for details on the columns.
131133
132134
See Also
133135
--------
@@ -141,7 +143,9 @@ def write_head_pos(fname, pos):
141143
_check_fname(fname, overwrite=True)
142144
pos = np.array(pos, np.float64)
143145
if pos.ndim != 2 or pos.shape[1] != 10:
144-
raise ValueError("pos must be a 2D array of shape (N, 10)")
146+
raise ValueError(
147+
f"pos must be a 2D array of shape (N, 10), got shape {pos.shape}"
148+
)
145149
with open(fname, "wb") as fid:
146150
fid.write(
147151
" Time q1 q2 q3 q4 q5 "
@@ -157,16 +161,17 @@ def head_pos_to_trans_rot_t(quats):
157161
158162
Parameters
159163
----------
160-
quats : ndarray, shape (N, 10)
164+
quats : ndarray, shape (n_pos, 10)
161165
MaxFilter-formatted position and quaternion parameters.
166+
See :func:`mne.chpi.read_head_pos` for details on the columns.
162167
163168
Returns
164169
-------
165-
translation : ndarray, shape (N, 3)
170+
translation : ndarray, shape (n_pos, 3)
166171
Translations at each time point.
167-
rotation : ndarray, shape (N, 3, 3)
172+
rotation : ndarray, shape (n_pos, 3, 3)
168173
Rotations at each time point.
169-
t : ndarray, shape (N,)
174+
t : ndarray, shape (n_pos,)
170175
The time points.
171176
172177
See Also
@@ -929,7 +934,8 @@ def compute_head_pos(
929934
Returns
930935
-------
931936
quats : ndarray, shape (n_pos, 10)
932-
The ``[t, q1, q2, q3, x, y, z, gof, err, v]`` for each fit.
937+
MaxFilter-formatted head position parameters. The columns correspond to
938+
``[t, q1, q2, q3, x, y, z, gof, err, v]`` for each time point.
933939
934940
See Also
935941
--------

mne/forward/_compute_forward.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _dup_coil_set(coils, coord_frame, t):
3535
if t is not None:
3636
coord_frame = t["to"]
3737
for coil in coils:
38+
assert isinstance(coil, dict), f"Coil must be a dict, got {type(coil)}"
3839
for key in ("ex", "ey", "ez"):
3940
if key in coil:
4041
coil[key] = apply_trans(t["trans"], coil[key], False)
@@ -794,6 +795,7 @@ def _compute_forwards_meeg(rr, *, sensors, fwd_data, n_jobs, silent=False):
794795
mri_Q, bem_rr, fun = fwd_data["mri_Q"], fwd_data["bem_rr"], fwd_data["fun"]
795796
solutions = fwd_data["solutions"]
796797
del fwd_data
798+
rr = np.ascontiguousarray(rr) # usually true but not guaranteed, e.g. in dipole.py
797799
for coil_type, sens in sensors.items():
798800
coils = sens["defs"]
799801
compensator = sens.get("compensator", None)
@@ -835,6 +837,9 @@ def _compute_forwards(rr, *, bem, sensors, n_jobs, verbose=None):
835837
solver = bem.get("solver", "mne")
836838
_check_option("solver", solver, ("mne", "openmeeg"))
837839
if bem["is_sphere"] or solver == "mne":
840+
# This modifies "sensors" in place, so let's copy it in case the calling
841+
# function needs to reuse it (e.g., in simulate_raw.py)
842+
sensors = deepcopy(sensors)
838843
fwd_data = _prep_field_computation(rr, sensors=sensors, bem=bem, n_jobs=n_jobs)
839844
Bs = _compute_forwards_meeg(
840845
rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs

mne/simulation/tests/test_raw.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
compute_head_pos,
3939
get_chpi_info,
4040
read_head_pos,
41+
write_head_pos,
4142
)
4243
from mne.datasets import testing
4344
from mne.io import RawArray, read_raw_fif
@@ -54,6 +55,7 @@
5455
from mne.source_space._source_space import _compare_source_spaces
5556
from mne.surface import _get_ico_surface
5657
from mne.tests.test_chpi import _assert_quats
58+
from mne.transforms import _affine_to_quat
5759
from mne.utils import catch_logging
5860

5961
raw_fname_short = Path(__file__).parents[2] / "io" / "tests" / "data" / "test_raw.fif"
@@ -251,6 +253,11 @@ def test_simulate_raw_sphere(raw_data, tmp_path):
251253

252254
# head pos
253255
head_pos_sim = _get_head_pos_sim(raw)
256+
head_pos_sim_2 = np.zeros((len(head_pos_sim), 10))
257+
for ii, (t, mat) in enumerate(head_pos_sim.items()):
258+
head_pos_sim_2[ii, :7] = [t] + list(_affine_to_quat(mat))
259+
head_pos_sim_3 = tmp_path / "head_pos.txt"
260+
write_head_pos(head_pos_sim_3, head_pos_sim_2)
254261

255262
#
256263
# Test raw simulation with basic parameters
@@ -259,11 +266,9 @@ def test_simulate_raw_sphere(raw_data, tmp_path):
259266
cov = read_cov(cov_fname)
260267
cov["projs"] = raw.info["projs"]
261268
raw.info["bads"] = raw.ch_names[:1]
262-
sphere_norad = make_sphere_model("auto", None, raw.info)
263269
raw_meg = raw.copy().pick("meg")
264-
raw_sim = simulate_raw(
265-
raw_meg.info, stc, trans, src, sphere_norad, head_pos=head_pos_sim
266-
)
270+
raw_sim = simulate_raw(raw_meg.info, stc, trans, src, sphere, head_pos=head_pos_sim)
271+
raw_data = raw_sim[:][0]
267272
# Test IO on processed data
268273
test_outname = tmp_path / "sim_test_raw.fif"
269274
raw_sim.save(test_outname)
@@ -307,12 +312,14 @@ def test_simulate_raw_sphere(raw_data, tmp_path):
307312
)
308313
del raw_sim, raw_sim_2
309314

310-
# check that different interpolations are similar given small movements
315+
# check that different interpolations are similar given small movements,
316+
# using different input forms of head_pos
311317
raw_sim = simulate_raw(
312318
raw.info, stc, trans, src, sphere, head_pos=head_pos_sim, interp="linear"
313319
)
320+
assert_allclose(raw_sim.get_data("meg"), raw_data, rtol=0.02)
314321
raw_sim_hann = simulate_raw(
315-
raw.info, stc, trans, src, sphere, head_pos=head_pos_sim, interp="hann"
322+
raw.info, stc, trans, src, sphere, head_pos=head_pos_sim_3, interp="hann"
316323
)
317324
assert_allclose(raw_sim[:][0], raw_sim_hann[:][0], rtol=1e-1, atol=1e-14)
318325
del raw_sim_hann
@@ -391,7 +398,7 @@ def test_simulate_raw_bem(raw_data):
391398
cov = make_ad_hoc_cov(raw.info)
392399
# The tolerance for the BEM is surprisingly high but I get the same
393400
# result when using MNE-C and Xfit, even when using a proper 5120 BEM :(
394-
for use_raw, bem, tol in ((raw_sim_sph, sphere, 4), (raw_sim_bem, bem_fname, 31)):
401+
for use_raw, bem, tol in ((raw_sim_sph, sphere, 6), (raw_sim_bem, bem_fname, 31)):
395402
events = find_events(use_raw, "STI 014")
396403
assert len(locs) == 6
397404
evoked = Epochs(use_raw, events, 1, 0, tmax, baseline=None).average()
@@ -425,6 +432,18 @@ def test_simulate_raw_bem(raw_data):
425432
assert_allclose(amp0 / amp1, wf_sim[0] / wf_sim[1], rtol=1e-5)
426433
assert amp2 == 0
427434
assert raw_sim.n_times == ss.n_times
435+
# smoke test that different head positions can be used as well
436+
head_pos_sim = {1.0: raw.info["dev_head_t"]["trans"]}
437+
raw_sim_2 = simulate_raw(
438+
raw.info,
439+
ss,
440+
src=src_ss,
441+
bem=bem_fname,
442+
first_samp=first_samp,
443+
head_pos=head_pos_sim,
444+
)
445+
data_2 = raw_sim_2.get_data()
446+
assert_allclose(data, data_2, rtol=1e-7)
428447

429448

430449
@pytest.mark.slowtest # slow on Windows Azure

0 commit comments

Comments
 (0)