Skip to content

Commit 576d85b

Browse files
larsonersnwnde
authored andcommitted
BUG: Fix bug with plot_projs_topomap (mne-tools#11792)
1 parent 8b73696 commit 576d85b

File tree

2 files changed

+46
-30
lines changed

2 files changed

+46
-30
lines changed

mne/viz/tests/test_topomap.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
create_info,
2828
read_cov,
2929
EvokedArray,
30+
compute_proj_raw,
3031
Projection,
3132
)
3233
from mne._fiff.proj import make_eeg_average_ref_proj
@@ -71,6 +72,8 @@
7172
layout = read_layout("Vectorview-all")
7273
cov_fname = base_dir / "test-cov.fif"
7374

75+
fast_test = dict(res=8, contours=0, sensors=False)
76+
7477

7578
@pytest.mark.parametrize("constrained_layout", (False, True))
7679
def test_plot_topomap_interactive(constrained_layout):
@@ -135,32 +138,36 @@ def test_plot_projs_topomap():
135138
"""Test plot_projs_topomap."""
136139
projs = read_proj(ecg_fname)
137140
info = read_info(raw_fname)
138-
fast_test = {"res": 8, "contours": 0, "sensors": False}
139141
plot_projs_topomap(projs, info=info, colorbar=True, **fast_test)
140-
plt.close("all")
141-
ax = plt.subplot(111)
142+
_, ax = plt.subplots()
142143
projs[3].plot_topomap(info)
143144
plot_projs_topomap(projs[:1], info, axes=ax, **fast_test) # test axes
144-
plt.close("all")
145145
triux_info = read_info(triux_fname)
146146
plot_projs_topomap(triux_info["projs"][-1:], triux_info, **fast_test)
147-
plt.close("all")
148147
plot_projs_topomap(triux_info["projs"][:1], triux_info, **fast_test)
149-
plt.close("all")
150148
eeg_avg = make_eeg_average_ref_proj(info)
151149
eeg_avg.plot_topomap(info, **fast_test)
152-
plt.close("all")
153150
# test vlims
154151
for vlim in ("joint", (-1, 1), (None, 0.5), (0.5, None), (None, None)):
155152
plot_projs_topomap(projs[:-1], info, vlim=vlim, colorbar=True)
156-
plt.close("all")
157153

158154
eeg_proj = make_eeg_average_ref_proj(info)
159155
info_meg = pick_info(info, pick_types(info, meg=True, eeg=False))
160156
with pytest.raises(ValueError, match="Missing channels"):
161157
plot_projs_topomap([eeg_proj], info_meg)
162158

163159

160+
@pytest.mark.parametrize("vlim", ("joint", None))
161+
@pytest.mark.parametrize("meg", ("combined", "separate"))
162+
def test_plot_projs_topomap_joint(meg, vlim, raw):
163+
"""Test that plot_projs_topomap works with joint vlim."""
164+
if vlim is None:
165+
vlim = (None, None)
166+
projs = compute_proj_raw(raw, meg=meg)
167+
fig = plot_projs_topomap(projs, info=raw.info, vlim=vlim, **fast_test)
168+
assert len(fig.axes) == 4 # 2 mag, 2 grad
169+
170+
164171
def test_plot_topomap_animation(capsys):
165172
"""Test topomap plotting."""
166173
# evoked
@@ -322,7 +329,6 @@ def test_plot_topomap_basic():
322329
"""Test basics of topomap plotting."""
323330
evoked = read_evokeds(evoked_fname, "Left Auditory", baseline=(None, 0))
324331
res = 8
325-
fast_test = dict(res=res, contours=0, sensors=False, time_unit="s")
326332
fast_test_noscale = dict(res=res, contours=0, sensors=False)
327333
ev_bad = evoked.copy().pick(picks="eeg")
328334
ev_bad.pick(ev_bad.ch_names[:2])
@@ -649,8 +655,6 @@ def test_plot_arrowmap(evoked):
649655
@testing.requires_testing_data
650656
def test_plot_topomap_neuromag122():
651657
"""Test topomap plotting."""
652-
res = 8
653-
fast_test = dict(res=res, contours=0, sensors=False)
654658
evoked = read_evokeds(evoked_fname, "Left Auditory", baseline=(None, 0))
655659
evoked.pick(picks="grad")
656660
evoked.pick(evoked.ch_names[:122])

mne/viz/topomap.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -474,34 +474,46 @@ def _plot_projs_topomap(
474474
projs = _check_type_projs(projs)
475475
_validate_type(info, "info", "info")
476476

477-
types, datas, poss, spheres, outliness, ch_typess = [], [], [], [], [], []
477+
# Preprocess projs to deal with joint MEG projectors. If we duplicate these and
478+
# split into mag and grad, they should work as expected
479+
info_names = _clean_names(info["ch_names"], remove_whitespace=True)
480+
use_projs = list()
481+
for proj in projs:
482+
proj = _eliminate_zeros(proj) # gh 5641, makes a copy
483+
proj["data"]["col_names"] = _clean_names(
484+
proj["data"]["col_names"],
485+
remove_whitespace=True,
486+
)
487+
picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True)
488+
proj_types = info.get_channel_types(picks)
489+
unique_types = sorted(set(proj_types))
490+
for type_ in unique_types:
491+
proj_picks = np.where([proj_type == type_ for proj_type in proj_types])[0]
492+
use_projs.append(copy.deepcopy(proj))
493+
use_projs[-1]["data"]["data"] = proj["data"]["data"][:, proj_picks]
494+
use_projs[-1]["data"]["col_names"] = [
495+
proj["data"]["col_names"][pick] for pick in proj_picks
496+
]
497+
projs = use_projs
498+
499+
datas, poss, spheres, outliness, ch_typess = [], [], [], [], []
478500
for proj in projs:
479501
# get ch_names, ch_types, data
480-
proj = _eliminate_zeros(proj) # gh 5641
481-
ch_names = _clean_names(proj["data"]["col_names"], remove_whitespace=True)
482-
if vlim == "joint":
483-
ch_idxs = np.where(np.isin(info["ch_names"], proj["data"]["col_names"]))[0]
484-
these_ch_types = info.get_channel_types(ch_idxs, unique=True)
485-
# each projector should have only one channel type
486-
assert len(these_ch_types) == 1
487-
types.append(list(these_ch_types)[0])
488502
data = proj["data"]["data"].ravel()
489-
info_names = _clean_names(info["ch_names"], remove_whitespace=True)
490-
picks = pick_channels(info_names, ch_names, ordered=True)
503+
picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True)
491504
use_info = pick_info(info, picks)
505+
these_ch_types = use_info.get_channel_types(unique=True)
506+
assert len(these_ch_types) == 1 # should be guaranteed above
507+
ch_type = these_ch_types[0]
492508
(
493509
data_picks,
494510
pos,
495511
merge_channels,
496512
names,
497-
ch_type,
513+
_,
498514
this_sphere,
499515
clip_origin,
500-
) = _prepare_topomap_plot(
501-
use_info,
502-
_get_plot_ch_type(use_info, None),
503-
sphere=sphere,
504-
)
516+
) = _prepare_topomap_plot(use_info, ch_type, sphere=sphere)
505517
these_outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
506518
data = data[data_picks]
507519
if merge_channels:
@@ -530,8 +542,8 @@ def _plot_projs_topomap(
530542
# handle vmin/vmax
531543
vlims = [None for _ in range(len(datas))]
532544
if vlim == "joint":
533-
for _ch_type in set(types):
534-
idx = np.where(np.isin(types, _ch_type))[0]
545+
for _ch_type in set(ch_typess):
546+
idx = np.where(np.isin(ch_typess, _ch_type))[0]
535547
these_data = np.concatenate(np.array(datas, dtype=object)[idx])
536548
norm = all(these_data >= 0)
537549
_vl = _setup_vmin_vmax(these_data, vmin=None, vmax=None, norm=norm)

0 commit comments

Comments
 (0)