Skip to content

Commit 96d22f8

Browse files
tsbinnslarsoner
andauthored
Backport PR #13067 on branch maint/1.9 ([BUG] Fix taper weighting in computation of TFR multitaper power) (#13072)
Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent 672bdf4 commit 96d22f8

File tree

5 files changed

+51
-42
lines changed

5 files changed

+51
-42
lines changed

doc/changes/devel/13067.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_.

mne/export/_export.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ def export_raw(
2525
2626
%(export_warning)s
2727
28+
.. warning::
29+
When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the
30+
same as ``raw.annotations.orig_time``. This guarantees that the annotations are
31+
in the same reference frame as the samples. When
32+
:attr:`Raw.first_time <mne.io.Raw.first_time>` is not zero (e.g., after
33+
cropping), the onsets are automatically corrected so that onsets are always
34+
relative to the first sample.
35+
2836
Parameters
2937
----------
3038
%(fname_export_params)s
@@ -216,7 +224,6 @@ def _infer_check_export_fmt(fmt, fname, supported_formats):
216224

217225
supported_str = ", ".join(supported)
218226
raise ValueError(
219-
f"Format '{fmt}' is not supported. "
220-
f"Supported formats are {supported_str}."
227+
f"Format '{fmt}' is not supported. Supported formats are {supported_str}."
221228
)
222229
return fmt

mne/time_frequency/tests/test_tfr.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -255,22 +255,6 @@ def test_tfr_morlet():
255255
# computed within the method.
256256
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)
257257

258-
# test that averaging power across tapers when multitaper with
259-
# output='complex' gives the same as output='power'
260-
epoch_data = epochs.get_data()
261-
multitaper_power = tfr_array_multitaper(
262-
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power"
263-
)
264-
multitaper_complex = tfr_array_multitaper(
265-
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex"
266-
)
267-
268-
taper_dim = 2
269-
power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean(
270-
axis=taper_dim
271-
)
272-
assert_allclose(power_from_complex, multitaper_power)
273-
274258
print(itc) # test repr
275259
print(itc.ch_names) # test property
276260
itc += power # test add

mne/time_frequency/tfr.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def _make_dpss(
266266
The wavelets time series.
267267
"""
268268
Ws = list()
269+
Cs = list()
269270

270271
freqs = np.array(freqs)
271272
if np.any(freqs <= 0):
@@ -281,6 +282,7 @@ def _make_dpss(
281282

282283
for m in range(n_taps):
283284
Wm = list()
285+
Cm = list()
284286
for k, f in enumerate(freqs):
285287
if len(n_cycles) != 1:
286288
this_n_cycles = n_cycles[k]
@@ -302,12 +304,15 @@ def _make_dpss(
302304
real_offset = Wk.mean()
303305
Wk -= real_offset
304306
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
307+
Ck = np.sqrt(conc[m])
305308

306309
Wm.append(Wk)
310+
Cm.append(Ck)
307311

308312
Ws.append(Wm)
313+
Cs.append(Cm)
309314
if return_weights:
310-
return Ws, conc
315+
return Ws, Cs
311316
return Ws
312317

313318

@@ -529,15 +534,18 @@ def _compute_tfr(
529534
if method == "morlet":
530535
W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
531536
Ws = [W] # to have same dimensionality as the 'multitaper' case
537+
weights = None # no tapers for Morlet estimates
532538

533539
elif method == "multitaper":
534-
Ws = _make_dpss(
540+
Ws, weights = _make_dpss(
535541
sfreq,
536542
freqs,
537543
n_cycles=n_cycles,
538544
time_bandwidth=time_bandwidth,
539545
zero_mean=zero_mean,
546+
return_weights=True, # required for converting complex → power
540547
)
548+
weights = np.asarray(weights)
541549

542550
# Check wavelets
543551
if len(Ws[0][0]) > epoch_data.shape[2]:
@@ -560,7 +568,7 @@ def _compute_tfr(
560568
if ("avg_" in output) or ("itc" in output):
561569
out = np.empty((n_chans, n_freqs, n_times), dtype)
562570
elif output in ["complex", "phase"] and method == "multitaper":
563-
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
571+
out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype)
564572
else:
565573
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)
566574

@@ -571,7 +579,7 @@ def _compute_tfr(
571579

572580
# Parallelization is applied across channels.
573581
tfrs = parallel(
574-
my_cwt(channel, Ws, output, use_fft, "same", decim, method)
582+
my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
575583
for channel in epoch_data.transpose(1, 0, 2)
576584
)
577585

@@ -581,10 +589,8 @@ def _compute_tfr(
581589

582590
if ("avg_" not in output) and ("itc" not in output):
583591
# This is to enforce that the first dimension is for epochs
584-
if output in ["complex", "phase"] and method == "multitaper":
585-
out = out.transpose(2, 0, 1, 3, 4)
586-
else:
587-
out = out.transpose(1, 0, 2, 3)
592+
out = np.moveaxis(out, 1, 0)
593+
588594
return out
589595

590596

@@ -658,7 +664,7 @@ def _check_tfr_param(
658664
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim
659665

660666

661-
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
667+
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
662668
"""Aux. function to _compute_tfr.
663669
664670
Loops time-frequency transform across wavelets and epochs.
@@ -685,9 +691,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
685691
See numpy.convolve.
686692
decim : slice
687693
The decimation slice: e.g. power[:, decim]
688-
method : str | None
689-
Used only for multitapering to create tapers dimension in the output
690-
if ``output in ['complex', 'phase']``.
694+
weights : array, shape (n_tapers, n_wavelets) | None
695+
Concentration weights for each taper in the wavelets, if present.
691696
"""
692697
# Set output type
693698
dtype = np.float64
@@ -701,10 +706,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
701706
n_freqs = len(Ws[0])
702707
if ("avg_" in output) or ("itc" in output):
703708
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
704-
elif output in ["complex", "phase"] and method == "multitaper":
705-
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
709+
elif output in ["complex", "phase"] and weights is not None:
710+
tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype)
706711
else:
707712
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
713+
if weights is not None:
714+
weights = np.expand_dims(weights, axis=-1) # add singleton time dimension
708715

709716
# Loops across tapers.
710717
for taper_idx, W in enumerate(Ws):
@@ -719,6 +726,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
719726
# Loop across epochs
720727
for epoch_idx, tfr in enumerate(coefs):
721728
# Transform complex values
729+
if output not in ["complex", "phase"] and weights is not None:
730+
tfr = weights[taper_idx] * tfr # weight each taper estimate
722731
if output in ["power", "avg_power"]:
723732
tfr = (tfr * tfr.conj()).real # power
724733
elif output == "phase":
@@ -734,8 +743,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
734743
# Stack or add
735744
if ("avg_" in output) or ("itc" in output):
736745
tfrs += tfr
737-
elif output in ["complex", "phase"] and method == "multitaper":
738-
tfrs[taper_idx, epoch_idx] += tfr
746+
elif output in ["complex", "phase"] and weights is not None:
747+
tfrs[epoch_idx, taper_idx] += tfr
739748
else:
740749
tfrs[epoch_idx] += tfr
741750

@@ -749,9 +758,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
749758
if ("avg_" in output) or ("itc" in output):
750759
tfrs /= n_epochs
751760

752-
# Normalization by number of taper
753-
if n_tapers > 1 and output not in ["complex", "phase"]:
754-
tfrs /= n_tapers
761+
# Normalization by taper weights
762+
if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
763+
if "avg_" not in output: # add singleton epochs dimension to weights
764+
weights = np.expand_dims(weights, axis=0)
765+
tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
766+
if output == "avg_power_itc": # weight itc by the number of tapers
767+
tfrs.imag = tfrs.imag / n_tapers
768+
755769
return tfrs
756770

757771

mne/utils/docs.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
14941494

14951495
docdict["export_fmt_support_epochs"] = """\
14961496
Supported formats:
1497-
- EEGLAB (``.set``, uses :mod:`eeglabio`)
1497+
1498+
- EEGLAB (``.set``, uses :mod:`eeglabio`)
14981499
"""
14991500

15001501
docdict["export_fmt_support_evoked"] = """\
15011502
Supported formats:
1502-
- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`)
1503+
1504+
- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`)
15031505
"""
15041506

15051507
docdict["export_fmt_support_raw"] = """\
15061508
Supported formats:
1507-
- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv <https://github.com/bids-standard/pybv>`_)
1508-
- EEGLAB (``.set``, uses :mod:`eeglabio`)
1509-
- EDF (``.edf``, uses `edfio <https://github.com/the-siesta-group/edfio>`_)
1509+
1510+
- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv <https://github.com/bids-standard/pybv>`_)
1511+
- EEGLAB (``.set``, uses :mod:`eeglabio`)
1512+
- EDF (``.edf``, uses `edfio <https://github.com/the-siesta-group/edfio>`_)
15101513
""" # noqa: E501
15111514

15121515
docdict["export_warning"] = """\

0 commit comments

Comments
 (0)