Skip to content

Commit a297c4f

Browse files
authored
[MRG] [BUG] [ENH] [WIP] Bug fixes and enhancements for time-resolved spectral connectivity estimation (mne-tools#104)
1 parent 05f17b8 commit a297c4f

11 files changed

+568
-320
lines changed

doc/authors.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
.. _Szonja Weigl: https://github.com/weiglszonja
88
.. _Kenji Marshall: https://github.com/kenjimarshall
99
.. _Sezan Mert: https://github.com/SezanMert
10+
.. _Santeri Ruuskanen: https://github.com/ruuskas

doc/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@
197197
'use_edit_page_button': False,
198198
'navigation_with_keys': False,
199199
'show_toc_level': 1,
200-
'navbar_end': ['version-switcher', 'navbar-icon-links'],
200+
'navbar_end': ['theme-switcher', 'version-switcher', 'navbar-icon-links'],
201+
'secondary_sidebar_items': ['page-toc'],
201202
}
202203
# Custom sidebar templates, maps document names to template names.
203204
html_sidebars = {

doc/whats_new.rst

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,35 @@ Here we list a changelog of MNE-connectivity.
1919
Version 0.5 (Unreleased)
2020
------------------------
2121

22-
...
22+
This version has major changes in :func:`mne_connectivity.spectral_connectivity_time`. Several bugs are fixed, and the
23+
function now computes static connectivity over time, as opposed to static connectivity over trials computed by :func:`mne_connectivity.spectral_connectivity_epochs`.
2324

2425
Enhancements
2526
~~~~~~~~~~~~
2627

27-
-
28+
- Add the ``PLI`` and ``wPLI`` methods in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
29+
- Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
30+
- Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
31+
- Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`).
2832

2933
Bug
3034
~~~
3135

32-
-
36+
- When using the ``multitaper`` mode in :func:`mne_connectivity.spectral_connectivity_time`, average CSD over tapers instead of the complex signal by `Santeri Ruuskanen`_ (:gh:`104`).
37+
- Average over time when computing connectivity measures in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
38+
- Fix support for multiple connectivity methods in calls to :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
39+
- Fix bug with the ``indices`` parameter in :func:`mne_connectivity.spectral_connectivity_time`, the behavior is now as expected by `Santeri Ruuskanen`_ (:gh:`104`).
40+
- Fix bug with parallel computation in :func:`mne_connectivity.spectral_connectivity_time`, add instructions for memory mapping in doc by `Santeri Ruuskanen`_ (:gh:`104`).
3341

3442
API
3543
~~~
3644

37-
-
45+
- Streamline the API of :func:`mne_connectivity.spectral_connectivity_time` with :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`).
3846

3947
Authors
4048
~~~~~~~
4149

42-
*
50+
* `Santeri Ruuskanen`_
4351

4452
:doc:`Find out what was new in previous releases <whats_new_previous_releases>`
4553

environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: mne
1+
name: mne-connectivity
22
channels:
33
- conda-forge
44
dependencies:
@@ -20,5 +20,5 @@ dependencies:
2020
- pyvista>=0.32
2121
- pyvistaqt>=0.4
2222
- pyqt!=5.15.3
23-
- mne
23+
- mne>=1.0
2424
- h5netcdf

mne_connectivity/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,6 @@ def _check_skip_backend(name):
151151
if not has_imageio_ffmpeg():
152152
pytest.skip("Test skipped, requires imageio-ffmpeg")
153153
if name == 'pyvistaqt' and not _check_qt_version():
154-
pytest.skip("Test skipped, requires PyQt5.")
154+
pytest.skip("Test skipped, requires Python Qt bindings.")
155155
if name == 'pyvistaqt' and not has_pyvistaqt():
156156
pytest.skip("Test skipped, requires pyvistaqt")

mne_connectivity/spectral/epochs.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
859859
860860
See Also
861861
--------
862+
mne_connectivity.spectral_connectivity_time
862863
mne_connectivity.SpectralConnectivity
863864
mne_connectivity.SpectroTemporalConnectivity
864865
@@ -873,7 +874,9 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
873874
connectivity structure. Within each Epoch, it is assumed that the spectral
874875
measure is stationary. The spectral measures implemented in this function
875876
are computed across Epochs. **Thus, spectral measures computed with only
876-
one Epoch will result in errorful values.**
877+
one Epoch will result in errorful values and spectral measures computed
878+
with few Epochs will be unreliable.** Please see
879+
``spectral_connectivity_time`` for time-resolved connectivity estimation.
877880
878881
The spectral densities can be estimated using a multitaper method with
879882
digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier
@@ -891,11 +894,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
891894
indices = (np.array([0, 0, 0]), # row indices
892895
np.array([2, 3, 4])) # col indices
893896
894-
con_flat = spectral_connectivity(data, method='coh',
895-
indices=indices, ...)
897+
con = spectral_connectivity_epochs(data, method='coh',
898+
indices=indices, ...)
896899
897-
In this case con_flat.shape = (3, n_freqs). The connectivity scores are
898-
in the same order as defined indices.
900+
In this case con.get_data().shape = (3, n_freqs). The connectivity scores
901+
are in the same order as defined indices.
899902
900903
**Supported Connectivity Measures**
901904

mne_connectivity/spectral/tests/test_spectral.py

Lines changed: 104 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,8 @@
22
from numpy.testing import (assert_allclose, assert_array_almost_equal,
33
assert_array_less)
44
import pytest
5-
import warnings
6-
7-
import mne
8-
from mne import (EpochsArray, SourceEstimate, create_info,
9-
make_fixed_length_epochs)
5+
from mne import (EpochsArray, SourceEstimate, create_info)
106
from mne.filter import filter_data
11-
from mne.utils import _resource_path
12-
from mne_bids import BIDSPath, read_raw_bids
137

148
from mne_connectivity import (
159
SpectralConnectivity, spectral_connectivity_epochs,
@@ -478,15 +472,109 @@ def test_epochs_tmin_tmax(kind):
478472
assert len(w) == 1 # just one even though there were multiple epochs
479473

480474

481-
@pytest.mark.parametrize('method', ['coh', 'plv'])
475+
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
476+
@pytest.mark.parametrize(
477+
'mode', ['cwt_morlet', 'multitaper'])
478+
@pytest.mark.parametrize('data_option', ['sync', 'random'])
479+
def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
480+
"""Test time-resolved spectral connectivity with simulated phase-locked
481+
data."""
482+
rng = np.random.default_rng(0)
483+
n_epochs = 5
484+
n_channels = 3
485+
n_times = 1000
486+
sfreq = 250
487+
data = np.zeros((n_epochs, n_channels, n_times))
488+
if data_option == 'random':
489+
# Data is random, there should be no consistent phase differences.
490+
data = rng.random((n_epochs, n_channels, n_times))
491+
if data_option == 'sync':
492+
# Data consists of phase-locked 10Hz sine waves with constant phase
493+
# difference within each epoch.
494+
wave_freq = 10
495+
epoch_length = n_times / sfreq
496+
for i in range(n_epochs):
497+
for c in range(n_channels):
498+
phase = rng.random() * 10
499+
x = np.linspace(-wave_freq * epoch_length * np.pi + phase,
500+
wave_freq * epoch_length * np.pi + phase,
501+
n_times)
502+
data[i, c] = np.squeeze(np.sin(x))
503+
# the frequency band should contain the frequency at which there is a
504+
# hypothesized "connection"
505+
freq_band_low_limit = (8.)
506+
freq_band_high_limit = (13.)
507+
cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
508+
con = spectral_connectivity_time(data, method=method, mode=mode,
509+
sfreq=sfreq, fmin=freq_band_low_limit,
510+
fmax=freq_band_high_limit,
511+
cwt_freqs=cwt_freqs, n_jobs=1,
512+
faverage=True, average=True, sm_times=0)
513+
assert con.shape == (n_channels ** 2, len(con.freqs))
514+
con_matrix = con.get_data('dense')[..., 0]
515+
if data_option == 'sync':
516+
# signals are perfectly phase-locked, connectivity matrix should be
517+
# a lower triangular matrix of ones
518+
assert np.allclose(con_matrix,
519+
np.tril(np.ones(con_matrix.shape),
520+
k=-1),
521+
atol=0.01)
522+
if data_option == 'random':
523+
# signals are random, all connectivity values should be small
524+
# 0.5 is picked rather arbitrarily such that the obsolete wrong
525+
# implementation fails
526+
assert np.all(con_matrix) <= 0.5
527+
528+
529+
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
530+
@pytest.mark.parametrize(
531+
'cwt_freqs', [[8., 10.], [8, 10], 10., 10])
532+
def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
533+
"""Test time-resolved spectral connectivity with int and float values for
534+
cwt_freqs."""
535+
rng = np.random.default_rng(0)
536+
n_epochs = 5
537+
n_channels = 3
538+
n_times = 1000
539+
sfreq = 250
540+
data = np.zeros((n_epochs, n_channels, n_times))
541+
542+
# Data consists of phase-locked 10Hz sine waves with constant phase
543+
# difference within each epoch.
544+
wave_freq = 10
545+
epoch_length = n_times / sfreq
546+
for i in range(n_epochs):
547+
for c in range(n_channels):
548+
phase = rng.random() * 10
549+
x = np.linspace(-wave_freq * epoch_length * np.pi + phase,
550+
wave_freq * epoch_length * np.pi + phase,
551+
n_times)
552+
data[i, c] = np.squeeze(np.sin(x))
553+
# the frequency band should contain the frequency at which there is a
554+
# hypothesized "connection"
555+
con = spectral_connectivity_time(data, method=method, mode='cwt_morlet',
556+
sfreq=sfreq, fmin=np.min(cwt_freqs),
557+
fmax=np.max(cwt_freqs),
558+
cwt_freqs=cwt_freqs, n_jobs=1,
559+
faverage=True, average=True, sm_times=0)
560+
assert con.shape == (n_channels ** 2, len(con.freqs))
561+
con_matrix = con.get_data('dense')[..., 0]
562+
563+
# signals are perfectly phase-locked, connectivity matrix should be
564+
# a lower triangular matrix of ones
565+
assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1),
566+
atol=0.01)
567+
568+
569+
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
482570
@pytest.mark.parametrize(
483571
'mode', ['cwt_morlet', 'multitaper'])
484572
def test_spectral_connectivity_time_resolved(method, mode):
485573
"""Test time-resolved spectral connectivity."""
486574
sfreq = 50.
487575
n_signals = 3
488576
n_epochs = 2
489-
n_times = 256
577+
n_times = 1000
490578
trans_bandwidth = 2.
491579
tmin = 0.
492580
tmax = (n_times - 1) / sfreq
@@ -502,22 +590,21 @@ def test_spectral_connectivity_time_resolved(method, mode):
502590

503591
# define some frequencies for cwt
504592
freqs = np.arange(3, 20.5, 1)
505-
n_freqs = len(freqs)
506593

507594
# run connectivity estimation
508595
con = spectral_connectivity_time(
509-
data, freqs=freqs, method=method, mode=mode)
510-
assert con.shape == (n_epochs, n_signals * 2, n_freqs, n_times)
596+
data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode,
597+
n_cycles=5)
598+
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
511599
assert con.get_data(output='dense').shape == \
512-
(n_epochs, n_signals, n_signals, n_freqs, n_times)
513-
514-
# average over time
515-
conn_data = con.get_data(output='dense').mean(axis=-1)
516-
conn_data = conn_data.mean(axis=-1)
600+
(n_epochs, n_signals, n_signals, len(con.freqs))
517601

518602
# test the simulated signal
519603
triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T
520604

605+
# average over frequencies
606+
conn_data = con.get_data(output='dense').mean(axis=-1)
607+
521608
# the indices at which there is a correlation should be greater
522609
# then the rest of the components
523610
for epoch_idx in range(n_epochs):
@@ -526,95 +613,6 @@ def test_spectral_connectivity_time_resolved(method, mode):
526613
for idx, jdx in triu_inds)
527614

528615

529-
@pytest.mark.parametrize('method', ['coh', 'plv'])
530-
@pytest.mark.parametrize(
531-
'mode', ['morlet', 'multitaper'])
532-
def test_time_resolved_spectral_conn_regression(method, mode):
533-
"""Regression test against original implementation in Frites.
534-
535-
To see how the test dataset was generated, see
536-
``benchmarks/single_epoch_conn.py``.
537-
"""
538-
test_file_path_str = str(_resource_path(
539-
'mne_connectivity.tests',
540-
f'data/test_frite_dataset_{mode}_{method}.npy'))
541-
test_conn = np.load(test_file_path_str)
542-
543-
# paths to mne datasets - sample ECoG
544-
bids_root = mne.datasets.epilepsy_ecog.data_path()
545-
546-
# first define the BIDS path and load in the dataset
547-
bids_path = BIDSPath(root=bids_root, subject='pt1', session='presurgery',
548-
task='ictal', datatype='ieeg', extension='.vhdr')
549-
with warnings.catch_warnings():
550-
warnings.simplefilter("ignore")
551-
raw = read_raw_bids(bids_path=bids_path, verbose=False)
552-
line_freq = raw.info['line_freq']
553-
554-
# Pick only the ECoG channels, removing the ECG channels
555-
raw.pick_types(ecog=True)
556-
557-
# drop bad channels
558-
raw.drop_channels(raw.info['bads'])
559-
560-
# only pick the first three channels to lower RAM usage
561-
raw = raw.pick_channels(raw.ch_names[:3])
562-
563-
# Load the data
564-
raw.load_data()
565-
566-
# Then we remove line frequency interference
567-
raw.notch_filter(line_freq)
568-
569-
# crop data and then Epoch
570-
raw_copy = raw.copy()
571-
raw = raw.crop(tmin=0, tmax=4, include_tmax=False)
572-
epochs = make_fixed_length_epochs(raw=raw, duration=2., overlap=1.)
573-
574-
######################################################################
575-
# Perform basic test to match simulation data using time-resolved spec
576-
######################################################################
577-
# compare data to original run using Frites
578-
freqs = [30, 90]
579-
580-
# mode was renamed in mne-connectivity
581-
if mode == 'morlet':
582-
mode = 'cwt_morlet'
583-
conn = spectral_connectivity_time(
584-
epochs, freqs=freqs, n_jobs=1, method=method, mode=mode)
585-
586-
# frites only stores the upper triangular parts of the raveled array
587-
row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1)
588-
conn_data = conn.get_data(output='dense')[
589-
:, row_triu_inds, col_triu_inds, ...]
590-
assert_array_almost_equal(conn_data, test_conn)
591-
592-
######################################################################
593-
# Give varying set of frequency bands and frequencies to perform cWT
594-
######################################################################
595-
raw = raw_copy.crop(tmin=0, tmax=10, include_tmax=False)
596-
ch_names = epochs.ch_names
597-
epochs = make_fixed_length_epochs(raw=raw, duration=5, overlap=0.)
598-
599-
# sampling rate of my data
600-
sfreq = raw.info['sfreq']
601-
602-
# frequency bands of interest
603-
fois = np.array([[4, 8], [8, 12], [12, 16], [16, 32]])
604-
605-
# frequencies of Continuous Morlet Wavelet Transform
606-
freqs = np.arange(4., 32., 1)
607-
608-
# compute coherence
609-
cohs = spectral_connectivity_time(
610-
epochs, names=None, method=method, indices=None,
611-
sfreq=sfreq, foi=fois, sm_times=0.5, sm_freqs=1, sm_kernel='hanning',
612-
mode=mode, mt_bandwidth=None, freqs=freqs, n_cycles=5)
613-
assert cohs.get_data(output='dense').shape == (
614-
len(epochs), len(ch_names), len(ch_names), len(fois), len(epochs.times)
615-
)
616-
617-
618616
def test_save(tmp_path):
619617
"""Test saving results of spectral connectivity."""
620618
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)