Skip to content

Commit fa91217

Browse files
jjnurminenagramfort
authored andcommitted
MRG: handle different time vecs in plot_evoked_topo() (#5788)
* handle different time vecs in plot_evoked_topo() * fix calls in raw.py and format_coord * comments * fix overlong line (dammit) * whatsnew * rm unnecessary list creation * fix auto ylim for differently shaped data * test plot_evoked_topo for nonuniform times * more specific what's new [ci skip]
1 parent d4c24cb commit fa91217

File tree

4 files changed

+44
-30
lines changed

4 files changed

+44
-30
lines changed

doc/whats_new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ Changelog
164164

165165
- Add parameter ``rank='full'`` to :func:`mne.beamformer.make_lcmv`, which can be set to ``None`` to auto-compute the rank of the covariance matrix before regularization by `Marijn van Vliet`_
166166

167+
- Handle different time vectors in topography plots using :func:`mne.viz.plot_evoked_topo` by `Jussi Nurminen`_
167168

168169
Bug
169170
~~~

mne/viz/raw.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,9 +1173,9 @@ def plot_raw_psd_topo(raw, tmin=0., tmax=None, fmin=0., fmax=100., proj=False,
11731173
else:
11741174
y_label = 'Power'
11751175
show_func = partial(_plot_timeseries_unified, data=[psds], color=color,
1176-
times=freqs)
1176+
times=[freqs])
11771177
click_func = partial(_plot_timeseries, data=[psds], color=color,
1178-
times=freqs)
1178+
times=[freqs])
11791179
picks = _pick_data_channels(raw.info)
11801180
info = pick_info(raw.info, picks)
11811181

mne/viz/tests/test_topo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ def return_inds(d): # to test function kwarg to zorder arg of evoked.plot
153153
def test_plot_topo_single_ch():
154154
"""Test single channel topoplot with time cursor."""
155155
evoked = _get_epochs().average()
156-
fig = plot_evoked_topo(evoked, background_color='w')
156+
evoked2 = evoked.copy()
157+
# test plotting several evokeds on different time grids
158+
evoked.crop(-.19, 0)
159+
evoked2.crop(.05, .19)
160+
fig = plot_evoked_topo([evoked, evoked2], background_color='w')
157161
# test status bar message
158162
ax = plt.gca()
159163
assert ('MEG 0113' in ax.format_coord(.065, .63))

mne/viz/topo.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _plot_topo(info, times, show_func, click_func=None, layout=None,
184184
layout.pos[:, :2] /= layout.pos[:, :2].max(0)
185185

186186
# prepare callbacks
187-
tmin, tmax = times[[0, -1]]
187+
tmin, tmax = times[0], times[-1]
188188
click_func = show_func if click_func is None else click_func
189189
on_pick = partial(click_func, tmin=tmin, tmax=tmax, vmin=vmin,
190190
vmax=vmax, ylim=ylim, x_label=x_label,
@@ -342,13 +342,13 @@ def _plot_timeseries(ax, ch_idx, tmin, tmax, vmin, vmax, ylim, data, color,
342342
import matplotlib.pyplot as plt
343343
from matplotlib.colors import colorConverter
344344
picker_flag = False
345-
for data_, color_ in zip(data, color):
345+
for data_, color_, times_ in zip(data, color, times):
346346
if not picker_flag:
347347
# use large tol for picker so we can click anywhere in the axes
348-
ax.plot(times, data_[ch_idx], color=color_, picker=1e9)
348+
ax.plot(times_, data_[ch_idx], color=color_, picker=1e9)
349349
picker_flag = True
350350
else:
351-
ax.plot(times, data_[ch_idx], color=color_)
351+
ax.plot(times_, data_[ch_idx], color=color_)
352352

353353
if x_label is not None:
354354
ax.set(xlabel=x_label)
@@ -361,17 +361,26 @@ def _plot_timeseries(ax, ch_idx, tmin, tmax, vmin, vmax, ylim, data, color,
361361

362362
def _format_coord(x, y, labels, ax):
363363
"""Create status string based on cursor coordinates."""
364-
idx = np.abs(times - x).argmin()
364+
# find indices for datasets near cursor (if any)
365+
tdiffs = [np.abs(tvec - x).min() for tvec in times]
366+
nearby = [k for k, tdiff in enumerate(tdiffs) if
367+
tdiff < (tmax - tmin) / 100]
368+
timestr = '%6.3f s: ' % x
369+
if not nearby:
370+
return '%s Nothing here' % timestr
371+
nearby_data = [(data[n], labels[n], times[n]) for n in nearby]
365372
ylabel = ax.get_ylabel()
366373
unit = (ylabel[ylabel.find('(') + 1:ylabel.find(')')]
367374
if '(' in ylabel and ')' in ylabel else '')
368-
labels = [''] * len(data) if labels is None else labels
375+
labels = [''] * len(nearby_data) if labels is None else labels
369376
# try to estimate whether to truncate condition labels
370377
slen = 10 + sum([12 + len(unit) + len(label) for label in labels])
371378
bar_width = (ax.figure.get_size_inches() * ax.figure.dpi)[0] / 5.5
379+
# show labels and y values for datasets near cursor
372380
trunc_labels = bar_width < slen
373-
s = '%6.3f s: ' % times[idx]
374-
for data_, label in zip(data, labels):
381+
s = timestr
382+
for data_, label, tvec in nearby_data:
383+
idx = np.abs(tvec - x).argmin()
375384
s += '%7.2f %s' % (data_[ch_idx, idx], unit)
376385
if trunc_labels:
377386
label = (label if len(label) <= 10 else
@@ -437,16 +446,16 @@ def _plot_timeseries_unified(bn, ch_idx, tmin, tmax, vmin, vmax, ylim, data,
437446
"""Show multiple time series on topo using a single axes."""
438447
import matplotlib.pyplot as plt
439448
if not (ylim and not any(v is None for v in ylim)):
440-
ylim = np.array([np.min(data), np.max(data)])
449+
ylim = [min(np.min(d) for d in data), max(np.max(d) for d in data)]
441450
# Translation and scale parameters to take data->under_ax normalized coords
442451
_compute_scalings(bn, (tmin, tmax), ylim)
443452
pos = bn.pos
444453
data_lines = bn.data_lines
445454
ax = bn.ax
446455
# XXX These calls could probably be made faster by using collections
447-
for data_, color_ in zip(data, color):
456+
for data_, color_, times_ in zip(data, color, times):
448457
data_lines.append(ax.plot(
449-
bn.x_t + bn.x_s * times, bn.y_t + bn.y_s * data_[ch_idx],
458+
bn.x_t + bn.x_s * times_, bn.y_t + bn.y_s * data_[ch_idx],
450459
linewidth=0.5, color=color_, clip_on=True, clip_box=pos)[0])
451460
if vline:
452461
vline = np.array(vline) * bn.x_s + bn.x_t
@@ -631,10 +640,6 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
631640
else:
632641
color = cycle([color])
633642

634-
times = evoked[0].times
635-
if not all((e.times == times).all() for e in evoked):
636-
raise ValueError('All evoked.times must be the same')
637-
638643
noise_cov = _check_cov(noise_cov, evoked[0].info)
639644
if noise_cov is not None:
640645
evoked = [whiten_evoked(e, noise_cov) for e in evoked]
@@ -712,11 +717,10 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
712717
y_label.append('Amplitude (%s)' % unit)
713718

714719
if ylim is None:
715-
def set_ylim(x):
716-
return np.abs(x).max()
717-
ylim_ = [set_ylim([e.data[t] for e in evoked]) for t in picks]
718-
ymax = np.array(ylim_)
719-
ylim_ = (-ymax, ymax)
720+
# find maxima over all evoked data for each channel pick
721+
ymaxes = np.array([max(np.abs(e.data[t]).max() for e in evoked)
722+
for t in picks])
723+
ylim_ = (-ymaxes, ymaxes)
720724
elif isinstance(ylim, dict):
721725
ylim_ = _handle_default('ylim', ylim)
722726
ylim_ = [ylim_[kk] for kk in types_used]
@@ -730,20 +734,25 @@ def set_ylim(x):
730734

731735
data = [e.data for e in evoked]
732736
comments = [e.comment for e in evoked]
737+
times = [e.times for e in evoked]
738+
733739
show_func = partial(_plot_timeseries_unified, data=data, color=color,
734740
times=times, vline=vline, hline=hline,
735741
hvline_color=font_color)
736742
click_func = partial(_plot_timeseries, data=data, color=color, times=times,
737743
vline=vline, hline=hline, hvline_color=font_color,
738744
labels=comments)
739745

740-
fig = _plot_topo(info=info, times=times, show_func=show_func,
741-
click_func=click_func, layout=layout, colorbar=False,
742-
ylim=ylim_, cmap=None, layout_scale=layout_scale,
743-
border=border, fig_facecolor=fig_facecolor,
744-
font_color=font_color, axis_facecolor=axis_facecolor,
745-
title=title, x_label='Time (s)', y_label=y_label,
746-
unified=True, axes=axes)
746+
time_min = min([t[0] for t in times])
747+
time_max = max([t[-1] for t in times])
748+
fig = _plot_topo(info=info, times=[time_min, time_max],
749+
show_func=show_func, click_func=click_func, layout=layout,
750+
colorbar=False, ylim=ylim_, cmap=None,
751+
layout_scale=layout_scale, border=border,
752+
fig_facecolor=fig_facecolor, font_color=font_color,
753+
axis_facecolor=axis_facecolor, title=title,
754+
x_label='Time (s)', y_label=y_label, unified=True,
755+
axes=axes)
747756

748757
add_background_image(fig, fig_background)
749758

0 commit comments

Comments
 (0)