Skip to content

Commit 704b393

Browse files
BUG: Fix bug with least-squares sphere fit (#13178)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent e04c4d9 commit 704b393

File tree

12 files changed

+54
-69
lines changed

12 files changed

+54
-69
lines changed

doc/changes/devel/13178.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug with least-squares fitting of head origin using digitization points in :func:`mne.preprocessing.maxwell_filter`, by `Eric Larson`_.

mne/_fiff/tests/test_reference.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,11 @@ def test_set_eeg_reference_rest():
312312
same = [raw.ch_names.index(raw.info["bads"][0])]
313313
picks = np.setdiff1d(np.arange(len(raw.ch_names)), same)
314314
trans = None
315-
sphere = make_sphere_model("auto", "auto", raw.info)
315+
# Use fixed values from old sphere fit to reduce lines changed with fixed algorithm
316+
sphere = make_sphere_model(
317+
[-0.00413508, 0.01598787, 0.05175598],
318+
0.09100286249131773,
319+
)
316320
src = setup_volume_source_space(pos=20.0, sphere=sphere, exclude=30.0)
317321
assert src[0]["nuse"] == 223 # low but fast
318322
fwd = make_forward_solution(raw.info, trans, src, sphere)

mne/bem.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,10 +1073,10 @@ def get_fitting_dig(info, dig_kinds="auto", exclude_frontal=True, verbose=None):
10731073

10741074

10751075
@verbose
1076-
def _fit_sphere_to_headshape(info, dig_kinds, verbose=None):
1076+
def _fit_sphere_to_headshape(info, dig_kinds, *, verbose=None):
10771077
"""Fit a sphere to the given head shape."""
10781078
hsp = get_fitting_dig(info, dig_kinds)
1079-
radius, origin_head = _fit_sphere(np.array(hsp), disp=False)
1079+
radius, origin_head = _fit_sphere(np.array(hsp))
10801080
# compute origin in device coordinates
10811081
dev_head_t = info["dev_head_t"]
10821082
if dev_head_t is None:
@@ -1105,36 +1105,16 @@ def _fit_sphere_to_headshape(info, dig_kinds, verbose=None):
11051105
return radius, origin_head, origin_device
11061106

11071107

1108-
def _fit_sphere(points, disp="auto"):
1108+
def _fit_sphere(points):
11091109
"""Fit a sphere to an arbitrary set of points."""
1110-
if isinstance(disp, str) and disp == "auto":
1111-
disp = True if logger.level <= 20 else False
1112-
# initial guess for center and radius
1113-
radii = (np.max(points, axis=1) - np.min(points, axis=1)) / 2.0
1114-
radius_init = radii.mean()
1115-
center_init = np.median(points, axis=0)
1116-
1117-
# optimization
1118-
x0 = np.concatenate([center_init, [radius_init]])
1119-
1120-
def cost_fun(center_rad):
1121-
d = np.linalg.norm(points - center_rad[:3], axis=1) - center_rad[3]
1122-
d *= d
1123-
return d.sum()
1124-
1125-
def constraint(center_rad):
1126-
return center_rad[3] # radius must be >= 0
1127-
1128-
x_opt = fmin_cobyla(
1129-
cost_fun,
1130-
x0,
1131-
constraint,
1132-
rhobeg=radius_init,
1133-
rhoend=radius_init * 1e-6,
1134-
disp=disp,
1135-
)
1136-
1137-
origin, radius = x_opt[:3], x_opt[3]
1110+
# linear least-squares sphere fit, see for example
1111+
# https://stackoverflow.com/a/78909044
1112+
# TODO: At some point we should maybe reject outliers first...
1113+
A = np.c_[2 * points, np.ones((len(points), 1))]
1114+
b = (points**2).sum(axis=1)
1115+
x, _, _, _ = np.linalg.lstsq(A, b, rcond=1e-6)
1116+
origin = x[:3]
1117+
radius = np.sqrt(x[0] ** 2 + x[1] ** 2 + x[2] ** 2 + x[3])
11381118
return radius, origin
11391119

11401120

mne/channels/tests/test_montage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1453,7 +1453,7 @@ def renamer(x):
14531453
assert (orig_pos != new_pos).all()
14541454

14551455
r0 = _fit_sphere(new_pos)[1]
1456-
assert_allclose(r0, [-0.001021, 0.014554, 0.041404], atol=1e-4)
1456+
assert_allclose(r0, [-0.001043, 0.01469, 0.041448], atol=1e-4)
14571457
# spot check: Fp1 and Fpz
14581458
assert_allclose(
14591459
new_pos[:2],

mne/dipole.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1573,7 +1573,7 @@ def fit_dipole(
15731573
# Find the best-fitting sphere
15741574
inner_skull = _bem_find_surface(bem, "inner_skull")
15751575
inner_skull = inner_skull.copy()
1576-
R, r0 = _fit_sphere(inner_skull["rr"], disp=False)
1576+
R, r0 = _fit_sphere(inner_skull["rr"])
15771577
# r0 back to head frame for logging
15781578
r0 = apply_trans(mri_head_t["trans"], r0[np.newaxis, :])[0]
15791579
inner_skull["r0"] = r0

mne/io/curry/tests/test_curry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_read_raw_curry_rfDC(fname, tol, mock_dev_head_t, tmp_path):
280280
assert n_other == 15
281281
pos = np.array(pos)
282282
nn = np.array(nn)
283-
rad, origin = _fit_sphere(pos, disp=False)
283+
rad, origin = _fit_sphere(pos)
284284
assert 0.11 < rad < 0.13
285285
pos -= origin
286286
pos /= np.linalg.norm(pos, axis=1, keepdims=True)

mne/preprocessing/tests/test_fine_cal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def test_fine_cal_systems(system, tmp_path):
243243
int_order = 5
244244
corrs = (0.13, 0.0, 0.12)
245245
sfs = [4, 5, 125, 155]
246-
corr_tol = 0.3
246+
corr_tol = 0.34
247247
else:
248248
assert system == "triux", f"Unknown system {system}"
249249
raw = read_raw_fif(tri_fname)

mne/preprocessing/tests/test_maxwell.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173

174174
def _assert_n_free(raw_sss, lower, upper=None):
175175
"""Check the DOF."""
176+
__tracebackhide__ = True
176177
upper = lower if upper is None else upper
177178
n_free = raw_sss.info["proc_history"][0]["max_info"]["sss_info"]["nfree"]
178179
assert lower <= n_free <= upper, f"nfree fail: {lower} <= {n_free} <= {upper}"
@@ -367,10 +368,9 @@ def test_other_systems():
367368
hsp_path = kit_dir / "test_hsp.txt"
368369
raw_kit = read_raw_kit(sqd_path, str(mrk_path), str(elp_path), str(hsp_path))
369370
with (
370-
pytest.warns(RuntimeWarning, match="more than 20 mm from head frame origin"),
371371
pytest.raises(NotImplementedError, match="Cannot create forward solution with"),
372372
):
373-
maxwell_filter(raw_kit)
373+
maxwell_filter(raw_kit, verbose=True)
374374
with catch_logging() as log:
375375
raw_sss = maxwell_filter(
376376
raw_kit, origin=(0.0, 0.0, 0.04), ignore_ref=True, verbose=True
@@ -381,24 +381,21 @@ def test_other_systems():
381381
raw_kit, origin=(0.0, 0.0, 0.04), ignore_ref=True, mag_scale="auto"
382382
)
383383
assert_allclose(raw_sss._data, raw_sss_auto._data)
384-
# The KIT origin fit is terrible
385-
with pytest.warns(RuntimeWarning, match="more than 20 mm"):
386-
with catch_logging() as log:
387-
pytest.raises(
388-
RuntimeError, maxwell_filter, raw_kit, ignore_ref=True, regularize=None
389-
) # bad condition
390-
raw_sss = maxwell_filter(
391-
raw_kit,
392-
origin="auto",
393-
ignore_ref=True,
394-
bad_condition="info",
395-
verbose=True,
396-
)
384+
with catch_logging() as log:
385+
pytest.raises(
386+
RuntimeError, maxwell_filter, raw_kit, ignore_ref=True, regularize=None
387+
) # bad condition
388+
raw_sss = maxwell_filter(
389+
raw_kit,
390+
origin="auto",
391+
ignore_ref=True,
392+
bad_condition="info",
393+
verbose=True,
394+
)
397395
log = log.getvalue()
398-
assert "badly conditioned" in log
399-
assert "more than 20 mm from" in log
400-
# fits can differ slightly based on scipy version, so be lenient here
401-
_assert_n_free(raw_sss, 28, 34) # bad origin == brutal reg
396+
assert "badly conditioned" not in log
397+
assert "more than 20 mm from" not in log
398+
_assert_n_free(raw_sss, 67, 67)
402399
# Let's set the origin
403400
with catch_logging() as log:
404401
raw_sss = maxwell_filter(

mne/simulation/tests/test_raw.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,11 @@ def raw_data():
220220

221221
src = read_source_spaces(src_fname)
222222
trans = read_trans(trans_fname)
223-
sphere = make_sphere_model("auto", "auto", raw.info)
223+
# Use fixed values from old sphere fit to reduce lines changed with fixed algorithm
224+
sphere = make_sphere_model(
225+
[-0.00413508, 0.01598787, 0.05175598],
226+
0.09100286249131773,
227+
)
224228
stc = _make_stc(raw, src)
225229
return raw, src, stc, trans, sphere
226230

@@ -385,9 +389,9 @@ def test_simulate_raw_bem(raw_data):
385389
locs = np.concatenate([s["rr"][s["vertno"]] for s in src])
386390
tmax = (len(locs) - 1) / raw.info["sfreq"]
387391
cov = make_ad_hoc_cov(raw.info)
388-
# The tolerance for the BEM is surprisingly high (28) but I get the same
392+
# The tolerance for the BEM is surprisingly high but I get the same
389393
# result when using MNE-C and Xfit, even when using a proper 5120 BEM :(
390-
for use_raw, bem, tol in ((raw_sim_sph, sphere, 3), (raw_sim_bem, bem_fname, 31)):
394+
for use_raw, bem, tol in ((raw_sim_sph, sphere, 4), (raw_sim_bem, bem_fname, 31)):
391395
events = find_events(use_raw, "STI 014")
392396
assert len(locs) == 6
393397
evoked = Epochs(use_raw, events, 1, 0, tmax, baseline=None).average()

mne/surface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def get_meg_helmet_surf(info, trans=None, *, verbose=None):
220220
)
221221
hull = ConvexHull(rr)
222222
rr = rr[np.unique(hull.simplices)]
223-
R, center = _fit_sphere(rr, disp=False)
223+
R, center = _fit_sphere(rr)
224224
sph = _cart_to_sph(rr - center)[:, 1:]
225225
# add a point at the front of the helmet (where the face should be):
226226
# 90 deg az and maximal el (down from Z/up axis)

0 commit comments

Comments
 (0)