Skip to content

Commit d9c260c

Browse files
authored
Change calls to _tfr_from_mt with support for ndarrays (#286)
1 parent ecb31a9 commit d9c260c

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

mne_connectivity/spectral/epochs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
_psd_from_mt,
2626
_psd_from_mt_adaptive,
2727
)
28-
from mne.time_frequency.tfr import _tfr_from_mt, cwt, morlet
28+
from mne.time_frequency.tfr import cwt, morlet
2929
from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn
3030

3131
from ..base import SpectralConnectivity, SpectroTemporalConnectivity
@@ -464,8 +464,8 @@ def _tfr_csd_from_mt(x_mt, y_mt, weights_x, weights_y):
464464
The CSD between x and y.
465465
"""
466466
# expand weights dims to match x_mt and y_mt
467-
weights_x = np.expand_dims(weights_x, axis=(*np.arange(x_mt.ndim - 3), -1))
468-
weights_y = np.expand_dims(weights_y, axis=(*np.arange(y_mt.ndim - 3), -1))
467+
weights_x = weights_x[..., np.newaxis]
468+
weights_y = weights_y[..., np.newaxis]
469469
# compute CSD
470470
csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-3)
471471
denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-3)) * np.sqrt(
@@ -566,7 +566,10 @@ def _epoch_spectral_connectivity(
566566
if not is_tfr_con: # normal spectra (multitaper or Fourier)
567567
this_psd = _psd_from_mt(x_t, weights)
568568
else: # TFR spectra (multitaper)
569-
this_psd = np.array([_tfr_from_mt(epo_x, weights) for epo_x in x_t])
569+
# XXX: Move import to top when support for mne<1.10 is dropped
570+
from mne.time_frequency.tfr import _tfr_from_mt
571+
572+
this_psd = _tfr_from_mt(x_t, weights)
570573
else: # mode == 'cwt_morlet'
571574
this_psd = (x_t * x_t.conj()).real
572575
else: # compute spectral info from scratch

mne_connectivity/spectral/epochs_multivariate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from mne.parallel import parallel_func
1818
from mne.time_frequency import EpochsSpectrum, EpochsTFR
1919
from mne.time_frequency.multitaper import _psd_from_mt
20-
from mne.time_frequency.tfr import _tfr_from_mt
2120
from mne.utils import ProgressBar, _validate_type, logger
2221

2322

@@ -46,9 +45,10 @@ def _check_rank_input(rank, data, indices):
4645
data_arr = data.get_data(picks=np.arange(data.info["nchan"]))
4746
# Convert to power and aggregate over time before computing rank
4847
if "taper" in data._dims:
49-
data_arr = np.sum(
50-
[_tfr_from_mt(epoch, data.weights) for epoch in data_arr], axis=-1
51-
)
48+
# XXX: Move import to top when support for mne<1.10 is dropped
49+
from mne.time_frequency.tfr import _tfr_from_mt
50+
51+
data_arr = np.sum(_tfr_from_mt(data_arr, data.weights), axis=-1)
5252
else:
5353
data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1)
5454
else:

0 commit comments

Comments
 (0)