Skip to content

[MRG] [BUG] [ENH] [WIP] Bug fixes and enhancements for time-resolved spectral connectivity estimation #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 51 commits into from
Nov 21, 2022
Merged
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
41bd87b
FIX: compute connectivity over multiple tapers when mode='multitaper'…
ruuskas Aug 31, 2022
05ae8f6
require MNE-Python 1.0 or newer due to breaking changes in mne.time_f…
ruuskas Aug 31, 2022
9cc283c
update tests corresponding to API changes and use longer test signal
ruuskas Aug 31, 2022
924b4ab
update docstring: connectivity is not averaged over Epochs by default
ruuskas Sep 1, 2022
c5b75f8
fix docstring typo
ruuskas Sep 1, 2022
7699983
update docstring: faverage is False by default
ruuskas Sep 1, 2022
89f683a
update docstring: lower bound of frequency range for connectivity com…
ruuskas Sep 1, 2022
cb2dab5
update docstring: fmax may be None
ruuskas Sep 1, 2022
0fd34c9
update docstring
ruuskas Sep 2, 2022
f4992a7
new default for fmin
ruuskas Sep 2, 2022
cc1d69d
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas Sep 2, 2022
fa9bce7
improve docstring and warnings for spectral_connectivity_time
ruuskas Sep 2, 2022
9e9f983
fix bug with indices: connectivity is now computed correctly between …
ruuskas Sep 5, 2022
2e4ccd1
DOC: improve documentation
ruuskas Sep 5, 2022
926c330
change smoothing default to no smoothing
ruuskas Sep 7, 2022
b881c34
DOC: updates to main docstring
ruuskas Sep 7, 2022
73fb937
BUG: number of blocks is now computed correctly
ruuskas Sep 7, 2022
baaca26
add test for time-resolved connectivity with simulated data
ruuskas Sep 12, 2022
2a7e09e
Change block_size default to 1
ruuskas Sep 22, 2022
2a0be06
Add documentation for block_size
ruuskas Sep 22, 2022
53ac7b5
Change for more useful variable names
ruuskas Sep 22, 2022
bc61e54
Remove regression test
ruuskas Oct 5, 2022
7ce13bd
Remove block_size parameter
ruuskas Oct 5, 2022
c7dd18c
Improve documentation
ruuskas Oct 5, 2022
4d2c1f0
Improve comments
ruuskas Oct 5, 2022
1b6224f
Remove unused code
ruuskas Oct 5, 2022
054512b
Merge branch 'main' into spectral_time
adam2392 Oct 16, 2022
283a1a1
Fix style issues
ruuskas Oct 17, 2022
083aa1b
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas Oct 17, 2022
e89ac77
Add comment
ruuskas Oct 19, 2022
19603bd
Improve comment
ruuskas Oct 19, 2022
e5da3ae
Rename test function
ruuskas Oct 19, 2022
39f0ee6
Update docstring
ruuskas Oct 19, 2022
4b6311c
Add comments
ruuskas Oct 19, 2022
7c64633
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas Oct 19, 2022
f6684c0
DOC: Fix typos
ruuskas Nov 9, 2022
39fef93
DOC: Improve doc formulation
ruuskas Nov 9, 2022
08b5c79
DOC: Add note on memory mapping
ruuskas Nov 9, 2022
ccb0a2d
Remove unused names parameter
ruuskas Nov 9, 2022
fe727f7
Require sfreq with array input
ruuskas Nov 9, 2022
3b966ec
DOC: Improve documentation
ruuskas Nov 9, 2022
f082d6f
Add test for cwt_freqs
ruuskas Nov 9, 2022
84d073b
BUG: Fix spectral_connectivity time
ruuskas Nov 9, 2022
3e7f208
Compute weighted average over CSD
ruuskas Nov 14, 2022
71b61ad
Fix style
ruuskas Nov 15, 2022
9e073c6
Update the docstring of spectral_connectivity_time
ruuskas Nov 15, 2022
dcfbc8b
Remove unnecessary defaults
ruuskas Nov 15, 2022
fb9869f
Add entries in whats_new.rst and authors.inc
ruuskas Nov 17, 2022
eec0113
FIX: Test
larsoner Nov 17, 2022
b2dda41
FIX: Doc build
larsoner Nov 17, 2022
099f194
FIX: Not pre
larsoner Nov 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 51 additions & 20 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import xarray as xr
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper)
from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper,
dpss_windows)
from mne.utils import (logger, warn)

from ..base import (SpectralConnectivity, EpochSpectralConnectivity)
Expand Down Expand Up @@ -410,11 +411,25 @@ def _spectral_connectivity(data, method, kernel, foi_idx,
data, sfreq, freqs, n_cycles=n_cycles, output='complex',
decim=decim, n_jobs=n_jobs, **kw_cwt)
out = np.expand_dims(out, axis=2) # same dims with multitaper
weights = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments perhaps in the docstring for future developers of what the weights are intended to doing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this and tried to make the docstring better in general. For some reason, Sphinx now gives this error

mne-connectivity/mne_connectivity/spectral/time.py:docstring of mne_connectivity.spectral.time.spectral_connectivity_time:59: WARNING: Inline literal start-string without end-string.

I don't see any rogue inline literal start-strings and the HTML looks good. A quick Google search suggests there might be an issue with the configuration (or it's just me).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the changes into the PR and I can take a look.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added them already after posting the comment above.

elif mode == 'multitaper':
out = tfr_array_multitaper(
data, sfreq, freqs, n_cycles=n_cycles,
time_bandwidth=mt_bandwidth, output='complex', decim=decim,
n_jobs=n_jobs, **kw_mt)
if isinstance(n_cycles, (int, float)):
n_cycles = [n_cycles] * len(freqs)
mt_bandwidth = mt_bandwidth if mt_bandwidth else 4
n_tapers = int(np.floor(mt_bandwidth - 1))
weights = np.zeros((n_tapers, len(freqs), out.shape[-1]))
for i, (f, n_c) in enumerate(zip(freqs, n_cycles)):
window_length = np.arange(0., n_c / float(f), 1.0 / sfreq).shape[0]
half_nbw = mt_bandwidth / 2.
n_tapers = int(np.floor(mt_bandwidth - 1))
_, eigvals = dpss_windows(window_length, half_nbw, n_tapers,
sym=False)
weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis])
# weights have shape (n_tapers, n_freqs, n_times)
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@larsoner, @drammock, @britta-wstnr how does this look to you?

raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.")

Expand All @@ -427,9 +442,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx,
this_conn[m] = c_func(out, kernel, foi_idx, source_idx,
target_idx, n_jobs=n_jobs,
verbose=verbose, total=n_pairs,
faverage=faverage)
# mean over tapers
this_conn[m] = [c.mean(axis=1) for c in this_conn[m]]
faverage=faverage, weights=weights)

return this_conn

Expand All @@ -441,22 +454,29 @@ def _spectral_connectivity(data, method, kernel, foi_idx,
###############################################################################

def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total,
faverage):
faverage, weights):
"""Pairwise coherence.

Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs,
n_times)."""
# auto spectra (faster than w * w.conj())
s_auto = w.real ** 2 + w.imag ** 2

# smooth the auto spectra
s_auto = _smooth_spectra(s_auto, kernel)
if weights is not None:
psd = weights * w
psd = psd * np.conj(psd)
psd = psd.real.sum(axis=2)
psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0)
else:
psd = w.real ** 2 + w.imag ** 2
psd = np.squeeze(psd, axis=2)

# smooth the psd
psd = _smooth_spectra(psd, kernel)

def pairwise_coh(w_x, w_y):
s_xy = w[:, w_y] * np.conj(w[:, w_x])
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights)
s_xy = _smooth_spectra(s_xy, kernel)
s_xx = s_auto[:, w_x]
s_yy = s_auto[:, w_y]
s_xx = psd[:, w_x]
s_yy = psd[:, w_y]
out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \
np.sqrt(s_xx.mean(axis=-1, keepdims=True) *
s_yy.mean(axis=-1, keepdims=True))
Expand All @@ -474,13 +494,13 @@ def pairwise_coh(w_x, w_y):


def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total,
faverage):
faverage, weights):
"""Pairwise phase-locking value.

Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs,
n_times)."""
def pairwise_plv(w_x, w_y):
s_xy = w[:, w_y] * np.conj(w[:, w_x])
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights)
exp_dphi = s_xy / np.abs(s_xy)
exp_dphi = _smooth_spectra(exp_dphi, kernel)
# mean over time
Expand All @@ -500,13 +520,13 @@ def pairwise_plv(w_x, w_y):


def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total,
faverage):
faverage, weights):
"""Pairwise phase-lag index.

Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs,
n_times)."""
def pairwise_pli(w_x, w_y):
s_xy = w[:, w_y] * np.conj(w[:, w_x])
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights)
s_xy = _smooth_spectra(s_xy, kernel)
out = np.abs(np.mean(np.sign(np.imag(s_xy)),
axis=-1, keepdims=True))
Expand All @@ -524,13 +544,13 @@ def pairwise_pli(w_x, w_y):


def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total,
faverage):
faverage, weights):
"""Pairwise weighted phase-lag index.

Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs,
n_times)."""
def pairwise_wpli(w_x, w_y):
s_xy = w[:, w_y] * np.conj(w[:, w_x])
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights)
s_xy = _smooth_spectra(s_xy, kernel)
con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True))
con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True)
Expand All @@ -549,10 +569,10 @@ def pairwise_wpli(w_x, w_y):


def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total,
faverage):
faverage, weights):
"""Pairwise cross-spectra."""
def pairwise_cs(w_x, w_y):
out = w[:, w_x] * np.conj(w[:, w_y])
out = _compute_csd(w[:, w_y], w[:, w_x], weights)
out = _smooth_spectra(out, kernel)
if isinstance(foi_idx, np.ndarray) and faverage:
return _foi_average(out, foi_idx)
Expand All @@ -566,6 +586,17 @@ def pairwise_cs(w_x, w_y):
return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx))


def _compute_csd(x, y, weights):
"""Compute cross spectral density of signals x and y."""
if weights is not None:
s_xy = np.sum(weights * x * np.conj(weights * y), axis=-3)
s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=-3)
else:
s_xy = x * np.conj(y)
s_xy = np.squeeze(s_xy, axis=-3)
return s_xy


def _foi_average(conn, foi_idx):
"""Average inside frequency bands.

Expand Down