-
Notifications
You must be signed in to change notification settings - Fork 35
[MRG] [BUG] [ENH] [WIP] Bug fixes and enhancements for time-resolved spectral connectivity estimation #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 44 commits
41bd87b
05ae8f6
9cc283c
924b4ab
c5b75f8
7699983
89f683a
cb2dab5
0fd34c9
f4992a7
cc1d69d
fa9bce7
9e9f983
2e4ccd1
926c330
b881c34
73fb937
baaca26
2a7e09e
2a0be06
53ac7b5
bc61e54
7ce13bd
c7dd18c
4d2c1f0
1b6224f
054512b
283a1a1
083aa1b
e89ac77
19603bd
e5da3ae
39f0ee6
4b6311c
7c64633
f6684c0
39fef93
08b5c79
ccb0a2d
fe727f7
3b966ec
f082d6f
84d073b
3e7f208
71b61ad
9e073c6
dcfbc8b
fb9869f
eec0113
b2dda41
099f194
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,8 @@ | |
from numpy.testing import (assert_allclose, assert_array_almost_equal, | ||
assert_array_less) | ||
import pytest | ||
import warnings | ||
|
||
import mne | ||
from mne import (EpochsArray, SourceEstimate, create_info, | ||
make_fixed_length_epochs) | ||
from mne import (EpochsArray, SourceEstimate, create_info) | ||
from mne.filter import filter_data | ||
from mne.utils import _resource_path | ||
from mne_bids import BIDSPath, read_raw_bids | ||
|
||
from mne_connectivity import ( | ||
SpectralConnectivity, spectral_connectivity_epochs, | ||
|
@@ -478,15 +472,107 @@ def test_epochs_tmin_tmax(kind): | |
assert len(w) == 1 # just one even though there were multiple epochs | ||
|
||
|
||
@pytest.mark.parametrize('method', ['coh', 'plv']) | ||
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) | ||
@pytest.mark.parametrize( | ||
'mode', ['cwt_morlet', 'multitaper']) | ||
@pytest.mark.parametrize('data_option', ['sync', 'random']) | ||
def test_spectral_connectivity_time_phaselocked(method, mode, data_option): | ||
"""Test time-resolved spectral connectivity with simulated phase-locked data.""" | ||
rng = np.random.default_rng(0) | ||
n_epochs = 5 | ||
n_channels = 3 | ||
n_times = 1000 | ||
sfreq = 250 | ||
data = np.zeros((n_epochs, n_channels, n_times)) | ||
if data_option == 'random': | ||
# Data is random, there should be no consistent phase differences. | ||
data = rng.random((n_epochs, n_channels, n_times)) | ||
if data_option == 'sync': | ||
# Data consists of phase-locked 10Hz sine waves with constant phase | ||
# difference within each epoch. | ||
wave_freq = 10 | ||
epoch_length = n_times / sfreq | ||
for i in range(n_epochs): | ||
for c in range(n_channels): | ||
phase = rng.random() * 10 | ||
x = np.linspace(-wave_freq * epoch_length * np.pi + phase, | ||
wave_freq * epoch_length * np.pi + phase, | ||
n_times) | ||
data[i, c] = np.squeeze(np.sin(x)) | ||
# the frequency band should contain the frequency at which there is a hypothesized "connection" | ||
freq_band_low_limit = (8.) | ||
freq_band_high_limit = (13.) | ||
cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) | ||
con = spectral_connectivity_time(data, method=method, mode=mode, | ||
sfreq=sfreq, fmin=freq_band_low_limit, | ||
fmax=freq_band_high_limit, | ||
cwt_freqs=cwt_freqs, n_jobs=1, | ||
faverage=True, average=True, sm_times=0) | ||
assert con.shape == (n_channels ** 2, len(con.freqs)) | ||
con_matrix = con.get_data('dense')[..., 0] | ||
if data_option == 'sync': | ||
# signals are perfectly phase-locked, connectivity matrix should be | ||
# a lower triangular matrix of ones | ||
assert np.allclose(con_matrix, | ||
np.tril(np.ones(con_matrix.shape), | ||
k=-1), | ||
atol=0.01) | ||
if data_option == 'random': | ||
# signals are random, all connectivity values should be small | ||
# 0.5 is picked rather arbitrarily such that the obsolete wrong | ||
# implementation fails | ||
assert np.all(con_matrix) <= 0.5 | ||
|
||
|
||
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) | ||
@pytest.mark.parametrize( | ||
'cwt_freqs', [[8., 10.], [8, 10], 10., 10]) | ||
def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): | ||
"""Test time-resolved spectral connectivity with int and float values for | ||
cwt_freqs.""" | ||
rng = np.random.default_rng(0) | ||
n_epochs = 5 | ||
n_channels = 3 | ||
n_times = 1000 | ||
sfreq = 250 | ||
data = np.zeros((n_epochs, n_channels, n_times)) | ||
|
||
# Data consists of phase-locked 10Hz sine waves with constant phase | ||
# difference within each epoch. | ||
wave_freq = 10 | ||
epoch_length = n_times / sfreq | ||
for i in range(n_epochs): | ||
for c in range(n_channels): | ||
phase = rng.random() * 10 | ||
x = np.linspace(-wave_freq * epoch_length * np.pi + phase, | ||
wave_freq * epoch_length * np.pi + phase, | ||
n_times) | ||
data[i, c] = np.squeeze(np.sin(x)) | ||
# the frequency band should contain the frequency at which there is a | ||
# hypothesized "connection" | ||
con = spectral_connectivity_time(data, method=method, mode='cwt_morlet', | ||
sfreq=sfreq, fmin=np.min(cwt_freqs), | ||
fmax=np.max(cwt_freqs), | ||
cwt_freqs=cwt_freqs, n_jobs=1, | ||
faverage=True, average=True, sm_times=0) | ||
assert con.shape == (n_channels ** 2, len(con.freqs)) | ||
con_matrix = con.get_data('dense')[..., 0] | ||
|
||
# signals are perfectly phase-locked, connectivity matrix should be | ||
# a lower triangular matrix of ones | ||
assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), | ||
atol=0.01) | ||
|
||
|
||
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) | ||
@pytest.mark.parametrize( | ||
'mode', ['cwt_morlet', 'multitaper']) | ||
def test_spectral_connectivity_time_resolved(method, mode): | ||
"""Test time-resolved spectral connectivity.""" | ||
sfreq = 50. | ||
n_signals = 3 | ||
n_epochs = 2 | ||
n_times = 256 | ||
n_times = 1000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we increase the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was an issue with the length of the wavelets being longer than the signal at some point during testing. Now it appears that the earlier value 256 would work just fine. |
||
trans_bandwidth = 2. | ||
tmin = 0. | ||
tmax = (n_times - 1) / sfreq | ||
|
@@ -502,22 +588,21 @@ def test_spectral_connectivity_time_resolved(method, mode): | |
|
||
# define some frequencies for cwt | ||
freqs = np.arange(3, 20.5, 1) | ||
n_freqs = len(freqs) | ||
|
||
# run connectivity estimation | ||
con = spectral_connectivity_time( | ||
data, freqs=freqs, method=method, mode=mode) | ||
assert con.shape == (n_epochs, n_signals * 2, n_freqs, n_times) | ||
data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode, | ||
n_cycles=5) | ||
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) | ||
assert con.get_data(output='dense').shape == \ | ||
(n_epochs, n_signals, n_signals, n_freqs, n_times) | ||
|
||
# average over time | ||
conn_data = con.get_data(output='dense').mean(axis=-1) | ||
conn_data = conn_data.mean(axis=-1) | ||
(n_epochs, n_signals, n_signals, len(con.freqs)) | ||
|
||
# test the simulated signal | ||
triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T | ||
|
||
# average over frequencies | ||
conn_data = con.get_data(output='dense').mean(axis=-1) | ||
|
||
# the indices at which there is a correlation should be greater | ||
# then the rest of the components | ||
for epoch_idx in range(n_epochs): | ||
|
@@ -526,95 +611,6 @@ def test_spectral_connectivity_time_resolved(method, mode): | |
for idx, jdx in triu_inds) | ||
|
||
|
||
@pytest.mark.parametrize('method', ['coh', 'plv']) | ||
@pytest.mark.parametrize( | ||
'mode', ['morlet', 'multitaper']) | ||
def test_time_resolved_spectral_conn_regression(method, mode): | ||
"""Regression test against original implementation in Frites. | ||
|
||
To see how the test dataset was generated, see | ||
``benchmarks/single_epoch_conn.py``. | ||
""" | ||
test_file_path_str = str(_resource_path( | ||
'mne_connectivity.tests', | ||
f'data/test_frite_dataset_{mode}_{method}.npy')) | ||
test_conn = np.load(test_file_path_str) | ||
|
||
# paths to mne datasets - sample ECoG | ||
bids_root = mne.datasets.epilepsy_ecog.data_path() | ||
|
||
# first define the BIDS path and load in the dataset | ||
bids_path = BIDSPath(root=bids_root, subject='pt1', session='presurgery', | ||
task='ictal', datatype='ieeg', extension='.vhdr') | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore") | ||
raw = read_raw_bids(bids_path=bids_path, verbose=False) | ||
line_freq = raw.info['line_freq'] | ||
|
||
# Pick only the ECoG channels, removing the ECG channels | ||
raw.pick_types(ecog=True) | ||
|
||
# drop bad channels | ||
raw.drop_channels(raw.info['bads']) | ||
|
||
# only pick the first three channels to lower RAM usage | ||
raw = raw.pick_channels(raw.ch_names[:3]) | ||
|
||
# Load the data | ||
raw.load_data() | ||
|
||
# Then we remove line frequency interference | ||
raw.notch_filter(line_freq) | ||
|
||
# crop data and then Epoch | ||
raw_copy = raw.copy() | ||
raw = raw.crop(tmin=0, tmax=4, include_tmax=False) | ||
epochs = make_fixed_length_epochs(raw=raw, duration=2., overlap=1.) | ||
|
||
###################################################################### | ||
# Perform basic test to match simulation data using time-resolved spec | ||
###################################################################### | ||
# compare data to original run using Frites | ||
freqs = [30, 90] | ||
|
||
# mode was renamed in mne-connectivity | ||
if mode == 'morlet': | ||
mode = 'cwt_morlet' | ||
conn = spectral_connectivity_time( | ||
epochs, freqs=freqs, n_jobs=1, method=method, mode=mode) | ||
|
||
# frites only stores the upper triangular parts of the raveled array | ||
row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1) | ||
conn_data = conn.get_data(output='dense')[ | ||
:, row_triu_inds, col_triu_inds, ...] | ||
assert_array_almost_equal(conn_data, test_conn) | ||
|
||
###################################################################### | ||
# Give varying set of frequency bands and frequencies to perform cWT | ||
###################################################################### | ||
raw = raw_copy.crop(tmin=0, tmax=10, include_tmax=False) | ||
ch_names = epochs.ch_names | ||
epochs = make_fixed_length_epochs(raw=raw, duration=5, overlap=0.) | ||
|
||
# sampling rate of my data | ||
sfreq = raw.info['sfreq'] | ||
|
||
# frequency bands of interest | ||
fois = np.array([[4, 8], [8, 12], [12, 16], [16, 32]]) | ||
|
||
# frequencies of Continuous Morlet Wavelet Transform | ||
freqs = np.arange(4., 32., 1) | ||
|
||
# compute coherence | ||
cohs = spectral_connectivity_time( | ||
epochs, names=None, method=method, indices=None, | ||
sfreq=sfreq, foi=fois, sm_times=0.5, sm_freqs=1, sm_kernel='hanning', | ||
mode=mode, mt_bandwidth=None, freqs=freqs, n_cycles=5) | ||
assert cohs.get_data(output='dense').shape == ( | ||
len(epochs), len(ch_names), len(ch_names), len(fois), len(epochs.times) | ||
) | ||
|
||
|
||
def test_save(tmp_path): | ||
"""Test saving results of spectral connectivity.""" | ||
rng = np.random.RandomState(0) | ||
|
Uh oh!
There was an error while loading. Please reload this page.