Skip to content

Commit c461bad

Browse files
authored
Support complex output in EpochsSpectrum Welch (#11556)
1 parent 367ed0a commit c461bad

File tree

5 files changed

+97
-14
lines changed

5 files changed

+97
-14
lines changed

doc/changes/latest.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Enhancements
3232
- Add :meth:`mne.Info.save` to save an :class:`mne.Info` object to a fif file (:gh:`11401` by `Alex Rockhill`_)
3333
- Improved error message when downloads are corrupted for :func:`mne.datasets.sample.data_path` and related functions (:gh:`11407` by `Eric Larson`_)
3434
- Add support for ``skip_by_annotation`` in :func:`mne.io.Raw.notch_filter` (:gh:`11388` by `Mainak Jas`_)
35+
- Add support for ``output='complex'`` to :func:`mne.time_frequency.psd_array_welch` and when using ``method='welch'`` with :meth:`mne.Epochs.compute_psd` (:gh:`11556` by `Eric Larson`_)
3536
- Slightly adjusted the window title for :func:`mne.Epochs.plot` (:gh:`11419` by `Richard Höchenberger`_ and `Daniel McCloy`_)
3637
- Add :func:`mne.count_events` to count unique event types in a given event array (:gh:`11430` by `Clemens Brunner`_)
3738
- Add a video to :ref:`tut-freesurfer-mne` of a brain inflating from the pial surface to aid in understanding the inflated brain (:gh:`11440` by `Alex Rockhill`_)

mne/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def pytest_configure(config):
131131
ignore:Jupyter is migrating its paths.*:DeprecationWarning
132132
ignore:Widget\..* is deprecated\.:DeprecationWarning
133133
ignore:.*is deprecated in pyzmq.*:DeprecationWarning
134+
ignore:The `ipykernel.comm.Comm` class has been deprecated.*:DeprecationWarning
134135
# PySide6
135136
ignore:Enum value .* is marked as deprecated:DeprecationWarning
136137
ignore:Function.*is marked as deprecated, please check the documentation.*:DeprecationWarning

mne/time_frequency/psd.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _decomp_aggregate_mask(epoch, func, average, freq_sl):
5656
return spect
5757

5858

59-
def _spect_func(epoch, func, freq_sl, average):
59+
def _spect_func(epoch, func, freq_sl, average, *, output='power'):
6060
"""Aux function."""
6161
# Decide if we should split this to save memory or not, since doing
6262
# multiple calls will incur some performance overhead. Eventually we might
@@ -91,7 +91,7 @@ def _check_nfft(n, n_fft, n_per_seg, n_overlap):
9191
@verbose
9292
def psd_array_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0,
9393
n_per_seg=None, n_jobs=None, average='mean',
94-
window='hamming', *, verbose=None):
94+
window='hamming', *, output='power', verbose=None):
9595
"""Compute power spectral density (PSD) using Welch's method.
9696
9797
Welch's method is described in :footcite:t:`Welch1967`.
@@ -122,6 +122,15 @@ def psd_array_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0,
122122
%(window_psd)s
123123
124124
.. versionadded:: 0.22.0
125+
output : str
126+
The format of the returned ``psds`` array, ``'complex'`` or
127+
``'power'``:
128+
129+
* ``'power'`` : the power spectral density is returned.
130+
* ``'complex'`` : the complex fourier coefficients are returned per
131+
window.
132+
133+
.. versionadded:: 1.4.0
125134
%(verbose)s
126135
127136
Returns
@@ -145,6 +154,8 @@ def psd_array_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0,
145154
.. footbibliography::
146155
"""
147156
_check_option('average', average, (None, False, 'mean', 'median'))
157+
_check_option('output', output, ('power', 'complex'))
158+
mode = 'complex' if output == 'complex' else 'psd'
148159
n_fft = _ensure_int(n_fft, "n_fft")
149160
n_overlap = _ensure_int(n_overlap, "n_overlap")
150161
if n_per_seg is not None:
@@ -178,10 +189,10 @@ def psd_array_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0,
178189
from scipy.signal import spectrogram
179190
parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
180191
func = partial(spectrogram, noverlap=n_overlap, nperseg=n_per_seg,
181-
nfft=n_fft, fs=sfreq, window=window)
192+
nfft=n_fft, fs=sfreq, window=window, mode=mode)
182193
x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
183194
f_spect = parallel(my_spect_func(d, func=func, freq_sl=freq_sl,
184-
average=average)
195+
average=average, output=output)
185196
for d in x_splits)
186197
psds = np.concatenate(f_spect, axis=0)
187198
shape = dshape + (len(freqs),)

mne/time_frequency/spectrum.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def __init__(self, inst, method, fmin, fmax, tmin, tmax, picks,
252252
self._dims = ('channel', 'freq',)
253253
if method_kw.get('average', '') in (None, False):
254254
self._dims += ('segment',)
255-
if method_kw.get('output', '') == 'complex':
255+
if self._returns_complex_tapers(**method_kw):
256256
self._dims = self._dims[:-1] + ('taper',) + self._dims[-1:]
257257
# record data type (for repr and html_repr)
258258
self._data_type = ('Fourier Coefficients' if 'taper' in self._dims
@@ -316,7 +316,8 @@ def _repr_html_(self, caption=None):
316316

317317
def _check_values(self):
318318
"""Check PSD results for correct shape and bad values."""
319-
assert len(self._dims) == self._data.ndim
319+
assert len(self._dims) == self._data.ndim, \
320+
(self._dims, self._data.ndim)
320321
assert self._data.shape == self._shape
321322
# negative values OK if the spectrum is really fourier coefficients
322323
if 'taper' in self._dims:
@@ -333,13 +334,19 @@ def _check_values(self):
333334
warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}',
334335
UserWarning)
335336

337+
def _returns_complex_tapers(self, **method_kw):
338+
return (
339+
method_kw.get('output', '') == 'complex' and
340+
self.method == 'multitaper'
341+
)
342+
336343
def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose):
337344
# make the spectra
338345
result = self._psd_func(
339346
data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs,
340347
verbose=verbose)
341348
# assign ._data (handling unaggregated multitaper output)
342-
if method_kw.get('output', '') == 'complex':
349+
if self._returns_complex_tapers(**method_kw):
343350
fourier_coefs, freqs, weights = result
344351
self._data = fourier_coefs
345352
self._mt_weights = weights
@@ -357,7 +364,7 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose):
357364
method_kw)
358365
self._shape += (n_welch_segments,)
359366
# insert n_tapers
360-
if method_kw.get('output', '') == 'complex':
367+
if self._returns_complex_tapers(**method_kw):
361368
self._shape = (
362369
self._shape[:-1] + (self._mt_weights.size,) + self._shape[-1:])
363370
# we don't need these anymore, and they make save/load harder

mne/time_frequency/tests/test_spectrum.py

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from contextlib import nullcontext
12
from functools import partial
23

34
import numpy as np
45
import pytest
5-
from numpy.testing import assert_array_equal
6+
from numpy.testing import assert_array_equal, assert_allclose
67

8+
from mne import create_info, make_fixed_length_epochs
9+
from mne.io import RawArray
710
from mne import Annotations
811
from mne.time_frequency import read_spectrum
912
from mne.time_frequency.multitaper import _psd_from_mt
@@ -137,8 +140,12 @@ def _agg_helper(df, weights, group_cols):
137140

138141
@requires_pandas
139142
@pytest.mark.parametrize('long_format', (False, True))
140-
@pytest.mark.parametrize('method', ('welch', 'multitaper'))
141-
def test_unaggregated_spectrum_to_data_frame(raw, long_format, method):
143+
@pytest.mark.parametrize('method, output', [
144+
('welch', 'complex'),
145+
('welch', 'power'),
146+
('multitaper', 'complex'),
147+
])
148+
def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output):
142149
"""Test converting complex multitaper spectra to data frame."""
143150
from pandas.testing import assert_frame_equal
144151

@@ -149,8 +156,10 @@ def test_unaggregated_spectrum_to_data_frame(raw, long_format, method):
149156
.to_data_frame(long_format=long_format))
150157
# unaggregated welch or complex multitaper →
151158
# aggregate w/ pandas (to make sure we did reshaping right)
152-
kwargs = {'average': False} if method == 'welch' else {'output': 'complex'}
153-
spectrum = raw.compute_psd(method=method, **kwargs)
159+
kwargs = dict()
160+
if method == 'welch':
161+
kwargs.update(average=False, verbose='error')
162+
spectrum = raw.compute_psd(method=method, output=output, **kwargs)
154163
df = spectrum.to_data_frame(long_format=long_format)
155164
grouping_cols = ['freq']
156165
drop_cols = ['segment'] if method == 'welch' else ['taper']
@@ -169,7 +178,12 @@ def test_unaggregated_spectrum_to_data_frame(raw, long_format, method):
169178
gb = df.drop(columns=drop_cols).groupby(
170179
grouping_cols, as_index=False, observed=False)
171180
if method == 'welch':
172-
agg_df = gb.aggregate(np.nanmean)
181+
if output == 'complex':
182+
def _fun(x):
183+
return np.nanmean(np.abs(x))
184+
else:
185+
_fun = np.nanmean
186+
agg_df = gb.aggregate(_fun)
173187
else:
174188
agg_df = gb.apply(_agg_helper, spectrum._mt_weights, grouping_cols)
175189
# even with check_categorical=False, we know that the *data* matches;
@@ -235,3 +249,52 @@ def test_spectrum_proj(inst, request):
235249
with has_proj.info._unlock():
236250
has_proj.info['projs'] = no_proj.info['projs']
237251
assert has_proj == no_proj
252+
253+
254+
@pytest.mark.parametrize('method, average', [
255+
('welch', False),
256+
('welch', 'mean'),
257+
('multitaper', False),
258+
])
259+
def test_spectrum_complex(method, average):
260+
"""Test output='complex' support."""
261+
sfreq = 100
262+
n = 10 * sfreq
263+
freq = 3.
264+
phase = np.pi / 4 # should be recoverable
265+
data = np.cos(2 * np.pi * freq * np.arange(n) / sfreq + phase)[np.newaxis]
266+
raw = RawArray(data, create_info(1, sfreq, 'eeg'))
267+
epochs = make_fixed_length_epochs(raw, duration=2., preload=True)
268+
assert len(epochs) == 5
269+
assert len(epochs.times) == 2 * sfreq
270+
kwargs = dict(output='complex', method=method)
271+
if method == 'welch':
272+
kwargs['n_fft'] = sfreq
273+
ctx = pytest.warns(UserWarning, match='Zero value')
274+
want_dims = ('epoch', 'channel', 'freq')
275+
want_shape = (5, 1, sfreq // 2 + 1)
276+
if not average:
277+
want_dims = want_dims + ('segment',)
278+
want_shape = want_shape + (2,)
279+
kwargs['average'] = average
280+
else:
281+
assert method == 'multitaper'
282+
assert not average
283+
ctx = nullcontext()
284+
want_dims = ('epoch', 'channel', 'taper', 'freq')
285+
want_shape = (5, 1, 7, sfreq + 1)
286+
with ctx:
287+
spectrum = epochs.compute_psd(**kwargs)
288+
idx = np.argmin(np.abs(spectrum.freqs - freq))
289+
assert spectrum.freqs[idx] == freq
290+
assert spectrum._dims == want_dims
291+
assert spectrum.shape == want_shape
292+
data = spectrum.get_data()
293+
assert data.dtype == np.complex128
294+
coef = spectrum.get_data(fmin=freq, fmax=freq).mean(0)
295+
if method == 'multitaper':
296+
coef = coef[..., 0, :] # first taper
297+
elif not average:
298+
coef = coef.mean(-1) # over segments
299+
coef = coef.item()
300+
assert_allclose(np.angle(coef), phase, rtol=1e-4)

0 commit comments

Comments
 (0)