-
Notifications
You must be signed in to change notification settings - Fork 35
[MRG] [BUG] [ENH] [WIP] Bug fixes and enhancements for time-resolved spectral connectivity estimation #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit
Hold shift + click to select a range
41bd87b
FIX: compute connectivity over multiple tapers when mode='multitaper'…
ruuskas 05ae8f6
require MNE-Python 1.0 or newer due to breaking changes in mne.time_f…
ruuskas 9cc283c
update tests corresponding to API changes and use longer test signal
ruuskas 924b4ab
update docstring: connectivity is not averaged over Epochs by default
ruuskas c5b75f8
fix docstring typo
ruuskas 7699983
update docstring: faverage is False by default
ruuskas 89f683a
update docstring: lower bound of frequency range for connectivity com…
ruuskas cb2dab5
update docstring: fmax may be None
ruuskas 0fd34c9
update docstring
ruuskas f4992a7
new default for fmin
ruuskas cc1d69d
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas fa9bce7
improve docstring and warnings for spectral_connectivity_time
ruuskas 9e9f983
fix bug with indices: connectivity is now computed correctly between …
ruuskas 2e4ccd1
DOC: improve documentation
ruuskas 926c330
change smoothing default to no smoothing
ruuskas b881c34
DOC: updates to main docstring
ruuskas 73fb937
BUG: number of blocks is now computed correctly
ruuskas baaca26
add test for time-resolved connectivity with simulated data
ruuskas 2a7e09e
Change block_size default to 1
ruuskas 2a0be06
Add documentation for block_size
ruuskas 53ac7b5
Change for more useful variable names
ruuskas bc61e54
Remove regression test
ruuskas 7ce13bd
Remove block_size parameter
ruuskas c7dd18c
Improve documentation
ruuskas 4d2c1f0
Improve comments
ruuskas 1b6224f
Remove unused code
ruuskas 054512b
Merge branch 'main' into spectral_time
adam2392 283a1a1
Fix style issues
ruuskas 083aa1b
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas e89ac77
Add comment
ruuskas 19603bd
Improve comment
ruuskas e5da3ae
Rename test function
ruuskas 39f0ee6
Update docstring
ruuskas 4b6311c
Add comments
ruuskas 7c64633
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas f6684c0
DOC: Fix typos
ruuskas 39fef93
DOC: Improve doc formulation
ruuskas 08b5c79
DOC: Add note on memory mapping
ruuskas ccb0a2d
Remove unused names parameter
ruuskas fe727f7
Require sfreq with array input
ruuskas 3b966ec
DOC: Improve documentation
ruuskas f082d6f
Add test for cwt_freqs
ruuskas 84d073b
BUG: Fix spectral_connectivity time
ruuskas 3e7f208
Compute weighted average over CSD
ruuskas 71b61ad
Fix style
ruuskas 9e073c6
Update the docstring of spectral_connectivity_time
ruuskas dcfbc8b
Remove unnecessary defaults
ruuskas fb9869f
Add entries in whats_new.rst and authors.inc
ruuskas eec0113
FIX: Test
larsoner b2dda41
FIX: Doc build
larsoner 099f194
FIX: Not pre
larsoner File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,8 @@ | |
import xarray as xr | ||
from mne.epochs import BaseEpochs | ||
from mne.parallel import parallel_func | ||
from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper) | ||
from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper, | ||
dpss_windows) | ||
from mne.utils import (logger, warn) | ||
|
||
from ..base import (SpectralConnectivity, EpochSpectralConnectivity) | ||
|
@@ -410,11 +411,25 @@ def _spectral_connectivity(data, method, kernel, foi_idx, | |
data, sfreq, freqs, n_cycles=n_cycles, output='complex', | ||
decim=decim, n_jobs=n_jobs, **kw_cwt) | ||
out = np.expand_dims(out, axis=2) # same dims with multitaper | ||
weights = None | ||
elif mode == 'multitaper': | ||
out = tfr_array_multitaper( | ||
data, sfreq, freqs, n_cycles=n_cycles, | ||
time_bandwidth=mt_bandwidth, output='complex', decim=decim, | ||
n_jobs=n_jobs, **kw_mt) | ||
if isinstance(n_cycles, (int, float)): | ||
n_cycles = [n_cycles] * len(freqs) | ||
mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 | ||
n_tapers = int(np.floor(mt_bandwidth - 1)) | ||
weights = np.zeros((n_tapers, len(freqs), out.shape[-1])) | ||
for i, (f, n_c) in enumerate(zip(freqs, n_cycles)): | ||
window_length = np.arange(0., n_c / float(f), 1.0 / sfreq).shape[0] | ||
half_nbw = mt_bandwidth / 2. | ||
n_tapers = int(np.floor(mt_bandwidth - 1)) | ||
_, eigvals = dpss_windows(window_length, half_nbw, n_tapers, | ||
sym=False) | ||
weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) | ||
# weights have shape (n_tapers, n_freqs, n_times) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @larsoner, @drammock, @britta-wstnr how does this look to you? |
||
raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") | ||
|
||
|
@@ -427,9 +442,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, | |
this_conn[m] = c_func(out, kernel, foi_idx, source_idx, | ||
target_idx, n_jobs=n_jobs, | ||
verbose=verbose, total=n_pairs, | ||
faverage=faverage) | ||
# mean over tapers | ||
this_conn[m] = [c.mean(axis=1) for c in this_conn[m]] | ||
faverage=faverage, weights=weights) | ||
|
||
return this_conn | ||
|
||
|
@@ -441,22 +454,29 @@ def _spectral_connectivity(data, method, kernel, foi_idx, | |
############################################################################### | ||
|
||
def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, | ||
faverage): | ||
faverage, weights): | ||
"""Pairwise coherence. | ||
|
||
Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, | ||
n_times).""" | ||
# auto spectra (faster than w * w.conj()) | ||
s_auto = w.real ** 2 + w.imag ** 2 | ||
|
||
# smooth the auto spectra | ||
s_auto = _smooth_spectra(s_auto, kernel) | ||
if weights is not None: | ||
psd = weights * w | ||
psd = psd * np.conj(psd) | ||
psd = psd.real.sum(axis=2) | ||
psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) | ||
else: | ||
psd = w.real ** 2 + w.imag ** 2 | ||
psd = np.squeeze(psd, axis=2) | ||
|
||
# smooth the psd | ||
psd = _smooth_spectra(psd, kernel) | ||
|
||
def pairwise_coh(w_x, w_y): | ||
s_xy = w[:, w_y] * np.conj(w[:, w_x]) | ||
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) | ||
s_xy = _smooth_spectra(s_xy, kernel) | ||
s_xx = s_auto[:, w_x] | ||
s_yy = s_auto[:, w_y] | ||
s_xx = psd[:, w_x] | ||
s_yy = psd[:, w_y] | ||
out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ | ||
np.sqrt(s_xx.mean(axis=-1, keepdims=True) * | ||
s_yy.mean(axis=-1, keepdims=True)) | ||
|
@@ -474,13 +494,13 @@ def pairwise_coh(w_x, w_y): | |
|
||
|
||
def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, | ||
faverage): | ||
faverage, weights): | ||
"""Pairwise phase-locking value. | ||
|
||
Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, | ||
n_times).""" | ||
def pairwise_plv(w_x, w_y): | ||
s_xy = w[:, w_y] * np.conj(w[:, w_x]) | ||
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) | ||
exp_dphi = s_xy / np.abs(s_xy) | ||
exp_dphi = _smooth_spectra(exp_dphi, kernel) | ||
# mean over time | ||
|
@@ -500,13 +520,13 @@ def pairwise_plv(w_x, w_y): | |
|
||
|
||
def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, | ||
faverage): | ||
faverage, weights): | ||
"""Pairwise phase-lag index. | ||
|
||
Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, | ||
n_times).""" | ||
def pairwise_pli(w_x, w_y): | ||
s_xy = w[:, w_y] * np.conj(w[:, w_x]) | ||
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) | ||
s_xy = _smooth_spectra(s_xy, kernel) | ||
out = np.abs(np.mean(np.sign(np.imag(s_xy)), | ||
axis=-1, keepdims=True)) | ||
|
@@ -524,13 +544,13 @@ def pairwise_pli(w_x, w_y): | |
|
||
|
||
def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, | ||
faverage): | ||
faverage, weights): | ||
"""Pairwise weighted phase-lag index. | ||
|
||
Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, | ||
n_times).""" | ||
def pairwise_wpli(w_x, w_y): | ||
s_xy = w[:, w_y] * np.conj(w[:, w_x]) | ||
s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) | ||
s_xy = _smooth_spectra(s_xy, kernel) | ||
con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) | ||
con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) | ||
|
@@ -549,10 +569,10 @@ def pairwise_wpli(w_x, w_y): | |
|
||
|
||
def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, | ||
faverage): | ||
faverage, weights): | ||
"""Pairwise cross-spectra.""" | ||
def pairwise_cs(w_x, w_y): | ||
out = w[:, w_x] * np.conj(w[:, w_y]) | ||
out = _compute_csd(w[:, w_y], w[:, w_x], weights) | ||
out = _smooth_spectra(out, kernel) | ||
if isinstance(foi_idx, np.ndarray) and faverage: | ||
return _foi_average(out, foi_idx) | ||
|
@@ -566,6 +586,17 @@ def pairwise_cs(w_x, w_y): | |
return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) | ||
|
||
|
||
def _compute_csd(x, y, weights): | ||
"""Compute cross spectral density of signals x and y.""" | ||
if weights is not None: | ||
s_xy = np.sum(weights * x * np.conj(weights * y), axis=-3) | ||
s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=-3) | ||
else: | ||
s_xy = x * np.conj(y) | ||
s_xy = np.squeeze(s_xy, axis=-3) | ||
return s_xy | ||
|
||
|
||
def _foi_average(conn, foi_idx): | ||
"""Average inside frequency bands. | ||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add some comments perhaps in the docstring for future developers of what the weights are intended to doing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did this and tried to make the docstring better in general. For some reason, Sphinx now gives this error
I don't see any rogue inline literal start-strings and the HTML looks good. A quick Google search suggests there might be an issue with the configuration (or it's just me).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the changes into the PR and I can take a look.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added them already after posting the comment above.