diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index eb512adfd..9e223baa7 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -25,7 +25,7 @@ _psd_from_mt, _psd_from_mt_adaptive, ) -from mne.time_frequency.tfr import _tfr_from_mt, cwt, morlet +from mne.time_frequency.tfr import cwt, morlet from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn from ..base import SpectralConnectivity, SpectroTemporalConnectivity @@ -464,8 +464,8 @@ def _tfr_csd_from_mt(x_mt, y_mt, weights_x, weights_y): The CSD between x and y. """ # expand weights dims to match x_mt and y_mt - weights_x = np.expand_dims(weights_x, axis=(*np.arange(x_mt.ndim - 3), -1)) - weights_y = np.expand_dims(weights_y, axis=(*np.arange(y_mt.ndim - 3), -1)) + weights_x = weights_x[..., np.newaxis] + weights_y = weights_y[..., np.newaxis] # compute CSD csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-3) denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-3)) * np.sqrt( @@ -566,7 +566,10 @@ def _epoch_spectral_connectivity( if not is_tfr_con: # normal spectra (multitaper or Fourier) this_psd = _psd_from_mt(x_t, weights) else: # TFR spectra (multitaper) - this_psd = np.array([_tfr_from_mt(epo_x, weights) for epo_x in x_t]) + # XXX: Move import to top when support for mne<1.10 is dropped + from mne.time_frequency.tfr import _tfr_from_mt + + this_psd = _tfr_from_mt(x_t, weights) else: # mode == 'cwt_morlet' this_psd = (x_t * x_t.conj()).real else: # compute spectral info from scratch diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index ada7234b3..19b3bf5b0 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -17,7 +17,6 @@ from mne.parallel import parallel_func from mne.time_frequency import EpochsSpectrum, EpochsTFR from mne.time_frequency.multitaper import _psd_from_mt -from mne.time_frequency.tfr import _tfr_from_mt from mne.utils import ProgressBar, _validate_type, logger @@ -46,9 +45,10 @@ def _check_rank_input(rank, data, indices): data_arr = data.get_data(picks=np.arange(data.info["nchan"])) # Convert to power and aggregate over time before computing rank if "taper" in data._dims: - data_arr = np.sum( - [_tfr_from_mt(epoch, data.weights) for epoch in data_arr], axis=-1 - ) + # XXX: Move import to top when support for mne<1.10 is dropped + from mne.time_frequency.tfr import _tfr_from_mt + + data_arr = np.sum(_tfr_from_mt(data_arr, data.weights), axis=-1) else: data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1) else: