Skip to content

Commit ed40aef

Browse files
ruuskasadam2392
andauthored
[ENH] Multiple improvements to spectral_connectivity_time: ciPLV, and efficient computation of multiple metrics (#115)
* Add ciPLV: Add the corrected imaginary Phase-Locking-Value into the list of available connectivity metrics. * Speed up computation: All connectivity measures are now computed with only a single computation of pairwise cross spectrum. * Add the option to specify freqs in all modes: In some scenarios, users might want to specify the frequencies for time-frequency decomposition also when using multitapering. These changes allow users to specify the 'freqs' parameter to override the automatically determined frequencies. * BUG: Average over CSD instead of connectivity * Add option to use part of signal as padding: This adds the option to use the edges of the signal at each epoch as padding. The purpose of this is to avoid edge effects generated by the time-frequency transformation methods. * Fix test bug, use 'freqs' instead of 'cwt_freqs' * Fix bug with dpss windows: Sym is not a parameter of dpss_windows. (But is one of the underlying scipy.signal.dpss) * Only show progress bar if verbosity level is DEBUG: This change will skip the rendering of the connectivity computation progress bar if the logging level is not DEBUG. This is in line with MNE-Python, where progress bars are not shown at INFO or higher logging levels. Rendering the progress bar regardless of logging levels has the potential to cause unnecessary clutter in users' log files. * Require freqs in all tfr modes The user is required to specify the wavelet central frequencies in both multitaper and cwt_morlet tfr mode. The reasoning is that the underlying tfr implementations are very similar. This is in contrast to spectral_connectivity_epochs, where multitaper assumes that the spectrum is stationary and therefore no wavelets are used. * Require mne>=1.3 Signed-off-by: Adam Li <adam2392@gmail.com> Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 3901744 commit ed40aef

File tree

4 files changed

+350
-207
lines changed

4 files changed

+350
-207
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ Enhancements
2929
- Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
3030
- Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
3131
- Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`).
32+
- Add the ``ciPLV`` method in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`).
33+
- Add the option to use the edges of each epoch as padding in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`).
3234

3335
Bug
3436
~~~

mne_connectivity/spectral/tests/test_spectral.py

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def test_epochs_tmin_tmax(kind):
472472
assert len(w) == 1 # just one even though there were multiple epochs
473473

474474

475-
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
475+
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
476476
@pytest.mark.parametrize(
477477
'mode', ['cwt_morlet', 'multitaper'])
478478
@pytest.mark.parametrize('data_option', ['sync', 'random'])
@@ -504,11 +504,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
504504
# hypothesized "connection"
505505
freq_band_low_limit = (8.)
506506
freq_band_high_limit = (13.)
507-
cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
508-
con = spectral_connectivity_time(data, method=method, mode=mode,
507+
freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
508+
con = spectral_connectivity_time(data, freqs, method=method, mode=mode,
509509
sfreq=sfreq, fmin=freq_band_low_limit,
510510
fmax=freq_band_high_limit,
511-
cwt_freqs=cwt_freqs, n_jobs=1,
511+
n_jobs=1,
512512
faverage=True, average=True, sm_times=0)
513513
assert con.shape == (n_channels ** 2, len(con.freqs))
514514
con_matrix = con.get_data('dense')[..., 0]
@@ -526,12 +526,13 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
526526
assert np.all(con_matrix) <= 0.5
527527

528528

529-
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
529+
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
530530
@pytest.mark.parametrize(
531-
'cwt_freqs', [[8., 10.], [8, 10], 10., 10])
532-
def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
531+
'freqs', [[8., 10.], [8, 10], 10., 10])
532+
@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper'])
533+
def test_spectral_connectivity_time_freqs(method, freqs, mode):
533534
"""Test time-resolved spectral connectivity with int and float values for
534-
cwt_freqs."""
535+
freqs."""
535536
rng = np.random.default_rng(0)
536537
n_epochs = 5
537538
n_channels = 3
@@ -552,10 +553,10 @@ def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
552553
data[i, c] = np.squeeze(np.sin(x))
553554
# the frequency band should contain the frequency at which there is a
554555
# hypothesized "connection"
555-
con = spectral_connectivity_time(data, method=method, mode='cwt_morlet',
556-
sfreq=sfreq, fmin=np.min(cwt_freqs),
557-
fmax=np.max(cwt_freqs),
558-
cwt_freqs=cwt_freqs, n_jobs=1,
556+
con = spectral_connectivity_time(data, freqs, method=method,
557+
mode=mode, sfreq=sfreq,
558+
fmin=np.min(freqs),
559+
fmax=np.max(freqs), n_jobs=1,
559560
faverage=True, average=True, sm_times=0)
560561
assert con.shape == (n_channels ** 2, len(con.freqs))
561562
con_matrix = con.get_data('dense')[..., 0]
@@ -588,12 +589,12 @@ def test_spectral_connectivity_time_resolved(method, mode):
588589
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
589590
data = EpochsArray(data, info)
590591

591-
# define some frequencies for cwt
592+
# define some frequencies for tfr
592593
freqs = np.arange(3, 20.5, 1)
593594

594595
# run connectivity estimation
595596
con = spectral_connectivity_time(
596-
data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode,
597+
data, freqs, sfreq=sfreq, method=method, mode=mode,
597598
n_cycles=5)
598599
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
599600
assert con.get_data(output='dense').shape == \
@@ -613,6 +614,63 @@ def test_spectral_connectivity_time_resolved(method, mode):
613614
for idx, jdx in triu_inds)
614615

615616

617+
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
618+
@pytest.mark.parametrize(
619+
'mode', ['cwt_morlet', 'multitaper'])
620+
@pytest.mark.parametrize('padding', [0, 1, 5])
621+
def test_spectral_connectivity_time_padding(method, mode, padding):
622+
"""Test time-resolved spectral connectivity with padding."""
623+
sfreq = 50.
624+
n_signals = 3
625+
n_epochs = 2
626+
n_times = 300
627+
trans_bandwidth = 2.
628+
tmin = 0.
629+
tmax = (n_times - 1) / sfreq
630+
# 5Hz..15Hz
631+
fstart, fend = 5.0, 15.0
632+
data, _ = create_test_dataset(
633+
sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times,
634+
tmin=tmin, tmax=tmax,
635+
fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth)
636+
ch_names = np.arange(n_signals).astype(str).tolist()
637+
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
638+
data = EpochsArray(data, info)
639+
640+
# define some frequencies for tfr
641+
freqs = np.arange(3, 20.5, 1)
642+
643+
# run connectivity estimation
644+
if padding == 5:
645+
with pytest.raises(ValueError, match='Padding cannot be larger than '
646+
'half of data length'):
647+
con = spectral_connectivity_time(
648+
data, freqs, sfreq=sfreq, method=method, mode=mode,
649+
n_cycles=5, padding=padding)
650+
return
651+
else:
652+
con = spectral_connectivity_time(
653+
data, freqs, sfreq=sfreq, method=method, mode=mode,
654+
n_cycles=5, padding=padding)
655+
656+
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
657+
assert con.get_data(output='dense').shape == \
658+
(n_epochs, n_signals, n_signals, len(con.freqs))
659+
660+
# test the simulated signal
661+
triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T
662+
663+
# average over frequencies
664+
conn_data = con.get_data(output='dense').mean(axis=-1)
665+
666+
# the indices at which there is a correlation should be greater
667+
# then the rest of the components
668+
for epoch_idx in range(n_epochs):
669+
high_conn_val = conn_data[epoch_idx, 0, 1]
670+
assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx]
671+
for idx, jdx in triu_inds)
672+
673+
616674
def test_save(tmp_path):
617675
"""Test saving results of spectral connectivity."""
618676
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)