Skip to content

Change calls to _tfr_from_mt with support for ndarrays #286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
_psd_from_mt,
_psd_from_mt_adaptive,
)
from mne.time_frequency.tfr import _tfr_from_mt, cwt, morlet
from mne.time_frequency.tfr import cwt, morlet
from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn

from ..base import SpectralConnectivity, SpectroTemporalConnectivity
Expand Down Expand Up @@ -464,8 +464,8 @@ def _tfr_csd_from_mt(x_mt, y_mt, weights_x, weights_y):
The CSD between x and y.
"""
# expand weights dims to match x_mt and y_mt
weights_x = np.expand_dims(weights_x, axis=(*np.arange(x_mt.ndim - 3), -1))
weights_y = np.expand_dims(weights_y, axis=(*np.arange(y_mt.ndim - 3), -1))
weights_x = weights_x[..., np.newaxis]
weights_y = weights_y[..., np.newaxis]
# compute CSD
csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-3)
denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-3)) * np.sqrt(
Expand Down Expand Up @@ -566,7 +566,10 @@ def _epoch_spectral_connectivity(
if not is_tfr_con: # normal spectra (multitaper or Fourier)
this_psd = _psd_from_mt(x_t, weights)
else: # TFR spectra (multitaper)
this_psd = np.array([_tfr_from_mt(epo_x, weights) for epo_x in x_t])
# XXX: Move import to top when support for mne<1.10 is dropped
from mne.time_frequency.tfr import _tfr_from_mt

this_psd = _tfr_from_mt(x_t, weights)
else: # mode == 'cwt_morlet'
this_psd = (x_t * x_t.conj()).real
else: # compute spectral info from scratch
Expand Down
8 changes: 4 additions & 4 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from mne.parallel import parallel_func
from mne.time_frequency import EpochsSpectrum, EpochsTFR
from mne.time_frequency.multitaper import _psd_from_mt
from mne.time_frequency.tfr import _tfr_from_mt
from mne.utils import ProgressBar, _validate_type, logger


Expand Down Expand Up @@ -46,9 +45,10 @@ def _check_rank_input(rank, data, indices):
data_arr = data.get_data(picks=np.arange(data.info["nchan"]))
# Convert to power and aggregate over time before computing rank
if "taper" in data._dims:
data_arr = np.sum(
[_tfr_from_mt(epoch, data.weights) for epoch in data_arr], axis=-1
)
# XXX: Move import to top when support for mne<1.10 is dropped
from mne.time_frequency.tfr import _tfr_from_mt

data_arr = np.sum(_tfr_from_mt(data_arr, data.weights), axis=-1)
else:
data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1)
else:
Expand Down