diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 503ab0af0..758426d23 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -29,6 +29,8 @@ Enhancements - Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). - Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). - Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`). +- Add the ``ciPLV`` method in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`). +- Add the option to use the edges of each epoch as padding in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`). Bug ~~~ diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 8b4c71a83..3286bba88 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -472,7 +472,7 @@ def test_epochs_tmin_tmax(kind): assert len(w) == 1 # just one even though there were multiple epochs -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) @@ -504,11 +504,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) - cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) - con = spectral_connectivity_time(data, method=method, mode=mode, + freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) + con = spectral_connectivity_time(data, freqs, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, - cwt_freqs=cwt_freqs, n_jobs=1, + n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] @@ -526,12 +526,13 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): assert np.all(con_matrix) <= 0.5 -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( - 'cwt_freqs', [[8., 10.], [8, 10], 10., 10]) -def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): + 'freqs', [[8., 10.], [8, 10], 10., 10]) +@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) +def test_spectral_connectivity_time_freqs(method, freqs, mode): """Test time-resolved spectral connectivity with int and float values for - cwt_freqs.""" + freqs.""" rng = np.random.default_rng(0) n_epochs = 5 n_channels = 3 @@ -552,10 +553,10 @@ def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): data[i, c] = np.squeeze(np.sin(x)) # the frequency band should contain the frequency at which there is a # hypothesized "connection" - con = spectral_connectivity_time(data, method=method, mode='cwt_morlet', - sfreq=sfreq, fmin=np.min(cwt_freqs), - fmax=np.max(cwt_freqs), - cwt_freqs=cwt_freqs, n_jobs=1, + con = spectral_connectivity_time(data, freqs, method=method, + mode=mode, sfreq=sfreq, + fmin=np.min(freqs), + fmax=np.max(freqs), n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] @@ -588,12 +589,12 @@ def test_spectral_connectivity_time_resolved(method, mode): info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') data = EpochsArray(data, info) - # define some frequencies for cwt + # define some frequencies for tfr freqs = np.arange(3, 20.5, 1) # run connectivity estimation con = spectral_connectivity_time( - data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode, + data, freqs, sfreq=sfreq, method=method, mode=mode, n_cycles=5) assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ @@ -613,6 +614,63 @@ def test_spectral_connectivity_time_resolved(method, mode): for idx, jdx in triu_inds) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize( + 'mode', ['cwt_morlet', 'multitaper']) +@pytest.mark.parametrize('padding', [0, 1, 5]) +def test_spectral_connectivity_time_padding(method, mode, padding): + """Test time-resolved spectral connectivity with padding.""" + sfreq = 50. + n_signals = 3 + n_epochs = 2 + n_times = 300 + trans_bandwidth = 2. + tmin = 0. + tmax = (n_times - 1) / sfreq + # 5Hz..15Hz + fstart, fend = 5.0, 15.0 + data, _ = create_test_dataset( + sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times, + tmin=tmin, tmax=tmax, + fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth) + ch_names = np.arange(n_signals).astype(str).tolist() + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') + data = EpochsArray(data, info) + + # define some frequencies for tfr + freqs = np.arange(3, 20.5, 1) + + # run connectivity estimation + if padding == 5: + with pytest.raises(ValueError, match='Padding cannot be larger than ' + 'half of data length'): + con = spectral_connectivity_time( + data, freqs, sfreq=sfreq, method=method, mode=mode, + n_cycles=5, padding=padding) + return + else: + con = spectral_connectivity_time( + data, freqs, sfreq=sfreq, method=method, mode=mode, + n_cycles=5, padding=padding) + + assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) + assert con.get_data(output='dense').shape == \ + (n_epochs, n_signals, n_signals, len(con.freqs)) + + # test the simulated signal + triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T + + # average over frequencies + conn_data = con.get_data(output='dense').mean(axis=-1) + + # the indices at which there is a correlation should be greater + # then the rest of the components + for epoch_idx in range(n_epochs): + high_conn_val = conn_data[epoch_idx, 0, 1] + assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx] + for idx, jdx in triu_inds) + + def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d5262667f..8658814bf 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -9,46 +9,49 @@ from mne.parallel import parallel_func from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper, dpss_windows) -from mne.utils import (logger, warn, verbose) +from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import _compute_freqs, _compute_freq_mask +from .epochs import _compute_freq_mask from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, fill_doc @verbose @fill_doc -def spectral_connectivity_time(data, method='coh', average=False, +def spectral_connectivity_time(data, freqs, method='coh', average=False, indices=None, sfreq=None, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, - sm_freqs=1, sm_kernel='hanning', + sm_freqs=1, sm_kernel='hanning', padding=0, mode='cwt_morlet', mt_bandwidth=None, - cwt_freqs=None, n_cycles=7, decim=1, - n_jobs=1, verbose=None): - """Compute frequency- and time-frequency-domain connectivity measures. + n_cycles=7, decim=1, n_jobs=1, verbose=None): + """Compute time-frequency-domain connectivity measures. - This method computes time-resolved connectivity measures from epoched data. + This function computes spectral connectivity over time from epoched data. + The data may consist of a single epoch. The connectivity method(s) are specified using the ``method`` parameter. - All methods are based on estimates of the cross- and power spectral - densities (CSD/PSD) Sxy and Sxx, Syy. + All methods are based on time-resolved estimates of the cross- and + power spectral densities (CSD/PSD) Sxy and Sxx, Syy. Parameters ---------- data : array_like, shape (n_epochs, n_signals, n_times) | Epochs The data from which to compute connectivity. + freqs : array_like + Array of frequencies of interest for time-frequency decomposition. + Only the frequencies within the range specified by ``fmin`` and + ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be - ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. These are: - - * 'coh' : Coherence - * 'plv' : Phase-Locking Value (PLV) - * 'sxy' : Cross-spectrum - * 'pli' : Phase-Lag Index - * 'wpli': Weighted Phase-Lag Index + ``['coh', 'plv', 'ciplv', 'pli', 'wpli']``. These are: + * 'coh' : Coherence + * 'plv' : Phase-Locking Value (PLV) + * 'ciplv' : Corrected imaginary Phase-Locking Value + * 'pli' : Phase-Lag Index + * 'wpli' : Weighted Phase-Lag Index average : bool - Average connectivity scores over epochs. If True, output will be + Average connectivity scores over epochs. If ``True``, output will be an instance of :class:`SpectralConnectivity`, otherwise :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None @@ -61,12 +64,11 @@ def spectral_connectivity_time(data, method='coh', average=False, fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower - bounds. If `None`, the frequency corresponding to an epoch length of - 5 cycles is used. + bounds. If `None`, the lowest frequency in ``freqs`` is used. fmax : float | tuple of float | None The upper frequency of interest. Multiple bands are defined using a tuple, e.g. ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper - bounds. If `None`, ``sfreq/2`` is used. + bounds. If `None`, the highest frequency in ``freqs`` is used. fskip : int Omit every ``(fskip + 1)``-th frequency bin to decimate in frequency domain. @@ -82,6 +84,9 @@ def spectral_connectivity_time(data, method='coh', average=False, is equivalent to no smoothing. sm_kernel : {'square', 'hanning'} Smoothing kernel type. Choose either 'square' or 'hanning'. + padding : float + Amount of time to consider as padding at the beginning and end of each + epoch in seconds. See Notes for more information. mode : str Time-frequency decomposition method. Can be either: 'multitaper', or 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and @@ -93,11 +98,6 @@ def spectral_connectivity_time(data, method='coh', average=False, bandwidth (thus the frequency resolution) and the number of good tapers. See :func:`mne.time_frequency.tfr_array_multitaper` documentation. - cwt_freqs : array_like - Array of frequencies of interest for time-frequency decomposition. - Only used in 'cwt_morlet' mode. Only the frequencies within - the range specified by ``fmin`` and ``fmax`` are used. Required if - ``mode='cwt_morlet'``. Not used when ``mode='multitaper'``. n_cycles : float | array_like of float Number of cycles in the wavelet, either a fixed number or one per frequency. The number of cycles ``n_cycles`` and the frequencies of @@ -150,6 +150,14 @@ def spectral_connectivity_time(data, method='coh', average=False, using a weighted average, where the weights correspond to the concentration ratios between the DPSS windows. + Spectral estimation using multitaper or Morlet wavelets introduces edge + effects that depend on the length of the wavelet. To remove edge effects, + the parameter ``padding`` can be used to prune the edges of the signal. + Please see the documentation of + :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for details on wavelet length + (i.e., time window length). + By default, the connectivity between all signals is computed (only connections corresponding to the lower-triangular part of the connectivity matrix). If one is only interested in the connectivity @@ -184,7 +192,12 @@ def spectral_connectivity_time(data, method='coh', average=False, PLV = |E[Sxy/|Sxy|]| - 'sxy' : Cross spectrum Sxy + 'ciplv' : Corrected imaginary PLV (icPLV) :footcite:`BrunaEtAl2018` + given by:: + + |E[Im(Sxy/|Sxy|)]| + ciPLV = ------------------------------------ + sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: @@ -256,25 +269,13 @@ def spectral_connectivity_time(data, method='coh', average=False, if isinstance(method, str): method = [method] - # check that fmin corresponds to at least 5 cycles - dur = float(n_times) / sfreq - five_cycle_freq = 5. / dur + # defaults for fmin and fmax if fmin is None: - # use the 5 cycle freq. as default - fmin = five_cycle_freq - logger.info(f'Fmin was not specified. Using fmin={fmin:.2f}, which ' - 'corresponds to at least five cycles.') - else: - if np.any(fmin < five_cycle_freq): - warn('fmin=%0.3f Hz corresponds to %0.3f < 5 cycles ' - 'based on the epoch length %0.3f sec, need at least %0.3f ' - 'sec epochs or fmin=%0.3f. Spectrum estimate will be ' - 'unreliable.' % (np.min(fmin), dur * np.min(fmin), dur, - 5. / np.min(fmin), five_cycle_freq)) + fmin = np.min(freqs) + logger.info('Fmin was not specified. Using fmin=min(freqs)') if fmax is None: - fmax = sfreq / 2 - logger.info(f'Fmax was not specified. Using fmax={fmax:.2f}, which ' - f'corresponds to Nyquist.') + fmax = np.max(freqs) + logger.info('Fmax was not specified. Using fmax=max(freqs).') fmin = np.array((fmin,), dtype=float).ravel() fmax = np.array((fmax,), dtype=float).ravel() @@ -308,24 +309,30 @@ def spectral_connectivity_time(data, method='coh', average=False, target_idx = indices_use[1] n_pairs = len(source_idx) - # check cwt_freqs - if cwt_freqs is not None: - # check for single frequency - if isinstance(cwt_freqs, (int, float)): - cwt_freqs = [cwt_freqs] - # array conversion - cwt_freqs = np.asarray(cwt_freqs) - # check order for multiple frequencies - if len(cwt_freqs) >= 2: - delta_f = np.diff(cwt_freqs) - increase = np.all(delta_f > 0) - assert increase, "Frequencies should be in increasing order" - - # compute frequencies to analyze based on number of samples, - # sampling rate, specified wavelet frequencies and mode - freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) - - # compute the mask based on specified min/max and decimation factor + # check freqs + if isinstance(freqs, (int, float)): + freqs = [freqs] + # array conversion + freqs = np.asarray(freqs) + # check order for multiple frequencies + if len(freqs) >= 2: + delta_f = np.diff(freqs) + increase = np.all(delta_f > 0) + assert increase, "Frequencies should be in increasing order" + + # check that freqs corresponds to at least n_cycles cycles + dur = float(n_times) / sfreq + cycle_freq = n_cycles / dur + if np.any(freqs < cycle_freq): + raise ValueError('At least one value in n_cycles corresponds to a' + 'wavelet longer than the signal. Use less cycles, ' + 'higher frequencies, or longer epochs.') + # check for Nyquist + if np.any(freqs > sfreq / 2): + raise ValueError(f'Frequencies {freqs[freqs > sfreq / 2]} Hz are ' + f'larger than Nyquist = {sfreq / 2:.2f} Hz') + + # compute frequency mask based on specified min/max and decimation factor freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) # the frequency points where we compute connectivity @@ -357,15 +364,14 @@ def spectral_connectivity_time(data, method='coh', average=False, source_idx=source_idx, target_idx=target_idx, mode=mode, sfreq=sfreq, freqs=freqs, faverage=faverage, n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, - decim=decim, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, + decim=decim, padding=padding, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, verbose=verbose) for epoch_idx in np.arange(n_epochs): - epoch_idx = [epoch_idx] - conn_tr = _spectral_connectivity(data[epoch_idx, ...], **call_params) + logger.info(f' Processing epoch {epoch_idx+1} / {n_epochs} ...') + conn_tr = _spectral_connectivity(data[epoch_idx], **call_params) for m in method: - conn[m][epoch_idx, ...] = np.stack(conn_tr[m], - axis=1).squeeze(axis=-1) + conn[m][epoch_idx] = np.stack(conn_tr[m], axis=0) if indices is None: conn_flat = conn @@ -374,7 +380,7 @@ def spectral_connectivity_time(data, method='coh', average=False, this_conn = np.zeros((n_epochs, n_signals, n_signals) + conn_flat[m].shape[2:], dtype=conn_flat[m].dtype) - this_conn[:, source_idx, target_idx] = conn_flat[m][:, ...] + this_conn[:, source_idx, target_idx] = conn_flat[m] this_conn = this_conn.reshape((n_epochs, n_signals ** 2,) + conn_flat[m].shape[2:]) conn[m] = this_conn @@ -404,13 +410,55 @@ def spectral_connectivity_time(data, method='coh', average=False, def _spectral_connectivity(data, method, kernel, foi_idx, source_idx, target_idx, mode, sfreq, freqs, faverage, n_cycles, - mt_bandwidth, decim, kw_cwt, kw_mt, + mt_bandwidth, decim, padding, kw_cwt, kw_mt, n_jobs, verbose): """Estimate time-resolved connectivity for one epoch. - See spectral_connectivity_epochs.""" - n_pairs = len(source_idx) + Parameters + ---------- + data : array_like, shape (n_channels, n_times) + Time-series data. + method : list of str + List of connectivity metrics to compute. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + source_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``target_idx``. + target_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``source_idx``. + mode : str + Time-frequency transformation method. + sfreq : float + Sampling frequency. + freqs : array_like + Array of frequencies of interest for time-frequency decomposition. + Only the frequencies within the range specified by ``fmin`` and + ``fmax`` are used. + faverage : bool + Average over frequency bands. + n_cycles : float | array_like of float + Number of cycles in the wavelet, either a fixed number or one per + frequency. + mt_bandwidth : float | None + Multitaper time-bandwidth. + decim : int + Decimation factor after time-frequency + decomposition. + padding : float + Amount of time to consider as padding at the beginning and end of each + epoch in seconds. + Returns + ------- + this_conn : list of array + List of connectivity estimates corresponding to the metrics in + ``method``. Each element is an array of shape (n_pairs, n_freqs) or + (n_pairs, n_fbands) if ``faverage`` is `True`. + """ + n_pairs = len(source_idx) + data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': out = tfr_array_morlet( data, sfreq, freqs, n_cycles=n_cycles, output='complex', @@ -438,16 +486,25 @@ def _spectral_connectivity(data, method, kernel, foi_idx, else: raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") + out = np.squeeze(out, axis=0) + + if padding: + if padding < 0: + raise ValueError(f'Padding cannot be negative, got {padding}.') + if padding >= data.shape[-1] / sfreq / 2: + raise ValueError(f'Padding cannot be larger than half of data ' + f'length, got {padding}.') + pad_idx = int(np.floor(padding * sfreq / decim)) + out = out[..., pad_idx:-pad_idx] + weights = weights[..., pad_idx:-pad_idx] if weights is not None \ + else None + # compute for each connectivity method this_conn = {} - conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs, 'pli': _pli, - 'wpli': _wpli} - for m in method: - c_func = conn_func[m] - this_conn[m] = c_func(out, kernel, foi_idx, source_idx, - target_idx, n_jobs=n_jobs, - verbose=verbose, total=n_pairs, - faverage=faverage, weights=weights) + conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, + n_jobs, verbose, n_pairs, faverage, weights) + for i, m in enumerate(method): + this_conn[m] = [out[i] for out in conn] return this_conn @@ -458,141 +515,167 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### ############################################################################### -def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise coherence. +def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, + verbose, total, faverage, weights): + """Compute spectral connectivity in parallel. - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" + Parameters + ---------- + w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) + Time-frequency data (complex signal). + method : list of str + List of connectivity metrics to compute. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + source_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``target_idx``. + target_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``source_idx``. + n_jobs : int + Number of parallel jobs. + total : int + Number of pairs of signals. + faverage : bool + Average over frequency bands. + weights : array_like, shape (n_tapers, n_freqs, n_times) + Multitaper weights. - if weights is not None: - psd = weights * w - psd = psd * np.conj(psd) - psd = psd.real.sum(axis=2) - psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) - else: - psd = w.real ** 2 + w.imag ** 2 - psd = np.squeeze(psd, axis=2) - - # smooth the psd - psd = _smooth_spectra(psd, kernel) - - def pairwise_coh(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - s_xx = psd[:, w_x] - s_yy = psd[:, w_y] - out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ - np.sqrt(s_xx.mean(axis=-1, keepdims=True) * - s_yy.mean(axis=-1, keepdims=True)) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) + Returns + ------- + out : array_like, shape (n_pairs, n_methods, n_freqs_out) + Connectivity estimates for each signal pair, method, and frequency or + frequency band. + """ + if 'coh' in method: + # psd + if weights is not None: + psd = weights * w + psd = psd * np.conj(psd) + psd = psd.real.sum(axis=1) + psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) else: - return out - - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_coh, n_jobs=n_jobs, verbose=verbose, total=total) + psd = w.real ** 2 + w.imag ** 2 + psd = np.squeeze(psd, axis=1) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) + # smooth + psd = _smooth_spectra(psd, kernel) + else: + psd = None + # only show progress if verbosity level is DEBUG + if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10: + total = None -def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise phase-locking value. + # define the function to compute in parallel + parallel, my_pairwise_con, n_jobs = parallel_func( + _pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_plv(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - exp_dphi = s_xy / np.abs(s_xy) - exp_dphi = _smooth_spectra(exp_dphi, kernel) - # mean over time - exp_dphi_mean = exp_dphi.mean(axis=-1, keepdims=True) - out = np.abs(exp_dphi_mean) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out + return parallel( + my_pairwise_con(w, psd, s, t, method, kernel, + foi_idx, faverage, weights) + for s, t in zip(source_idx, target_idx)) - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_plv, n_jobs=n_jobs, verbose=verbose, total=total) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, + faverage, weights): + """Compute spectral connectivity metrics between two signals. + Parameters + ---------- + w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) + Time-frequency data. + psd : array_like, shape (n_chans, n_freqs, n_times) + Power spectrum between signals ``x`` and ``y``. + x : int + Channel index. + y : int + Channel index. + method : str + Connectivity method. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + faverage : bool + Average over frequency bands. + weights : array_like, shape (n_tapers, n_freqs, n_times) | None + Multitaper weights. -def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise phase-lag index. + Returns + ------- + out : list + List of connectivity estimates between signals ``x`` and ``y`` + corresponding to the methods in ``method``. Each element is an array + with shape (n_freqs,) or (n_fbands) depending on ``faverage``. + """ + w_x, w_y = w[x], w[y] + if weights is not None: + s_xy = np.sum(weights * w_x * np.conj(weights * w_y), axis=0) + s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0) + else: + s_xy = w_x * np.conj(w_y) + s_xy = np.squeeze(s_xy, axis=0) + s_xy = _smooth_spectra(s_xy, kernel) + out = [] + conn_func = {'plv': _plv, 'ciplv': _ciplv, 'pli': _pli, 'wpli': _wpli, + 'coh': _coh} + for m in method: + if m == 'coh': + s_xx = psd[x] + s_yy = psd[y] + out.append(conn_func[m](s_xx, s_yy, s_xy)) + else: + out.append(conn_func[m](s_xy)) - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_pli(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - out = np.abs(np.mean(np.sign(np.imag(s_xy)), - axis=-1, keepdims=True)) + for i, _ in enumerate(out): # mean inside frequency sliding window (if needed) if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out + out[i] = _foi_average(out[i], foi_idx) + # squeeze time dimension + out[i] = out[i].squeeze(axis=-1) - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_pli, n_jobs=n_jobs, verbose=verbose, total=total) + return out - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _plv(s_xy): + s_xy = s_xy / np.abs(s_xy) + plv = np.abs(s_xy.mean(axis=-1, keepdims=True)) + return plv -def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise weighted phase-lag index. - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_wpli(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) - con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) - out = con_num / con_den - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out +def _ciplv(s_xy): + s_xy = s_xy / np.abs(s_xy) + rplv = np.abs(np.mean(np.real(s_xy), axis=-1, keepdims=True)) + iplv = np.abs(np.mean(np.imag(s_xy), axis=-1, keepdims=True)) + ciplv = iplv / (np.sqrt(1 - rplv ** 2)) + return ciplv - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_wpli, n_jobs=n_jobs, verbose=verbose, total=total) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _pli(s_xy): + pli = np.abs(np.mean(np.sign(np.imag(s_xy)), + axis=-1, keepdims=True)) + return pli -def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise cross-spectra.""" - def pairwise_cs(w_x, w_y): - out = _compute_csd(w[:, w_y], w[:, w_x], weights) - out = _smooth_spectra(out, kernel) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out +def _wpli(s_xy): + con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) + con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) + wpli = con_num / con_den + return wpli - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_cs, n_jobs=n_jobs, verbose=verbose, total=total) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _coh(s_xx, s_yy, s_xy): + con_num = np.abs(s_xy.mean(axis=-1, keepdims=True)) + con_den = np.sqrt(s_xx.mean(axis=-1, keepdims=True) * + s_yy.mean(axis=-1, keepdims=True)) + coh = con_num / con_den + return coh def _compute_csd(x, y, weights): - """Compute cross spectral density of signals x and y.""" + """Compute cross spectral density between signals x and y.""" if weights is not None: s_xy = np.sum(weights * x * np.conj(weights * y), axis=-3) s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=-3) @@ -609,15 +692,15 @@ def _foi_average(conn, foi_idx): Parameters ---------- - conn : np.ndarray - Array of shape (..., n_freqs, n_times) - foi_idx : array_like - Array of indices describing frequency bounds of shape (n_foi, 2) + conn : array_like, shape (..., n_freqs, n_times) + Connectivity estimate array. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower frequency bounds of each frequency band. Returns ------- - conn_f : np.ndarray - Array of shape (..., n_foi, n_times) + conn_f : np.ndarray, shape (..., n_fbands, n_times) + Connectivity estimate array, averaged within frequency bands. """ # get the number of foi n_foi = foi_idx.shape[0] diff --git a/requirements.txt b/requirements.txt index 067ea5c72..a564429a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy scipy -mne>=1.1 +mne>=1.3 xarray netCDF4 h5netcdf