|
25 | 25 | _psd_from_mt,
|
26 | 26 | _psd_from_mt_adaptive,
|
27 | 27 | )
|
28 |
| -from mne.time_frequency.tfr import _tfr_from_mt, cwt, morlet |
| 28 | +from mne.time_frequency.tfr import cwt, morlet |
29 | 29 | from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn
|
30 | 30 |
|
31 | 31 | from ..base import SpectralConnectivity, SpectroTemporalConnectivity
|
@@ -464,8 +464,8 @@ def _tfr_csd_from_mt(x_mt, y_mt, weights_x, weights_y):
|
464 | 464 | The CSD between x and y.
|
465 | 465 | """
|
466 | 466 | # expand weights dims to match x_mt and y_mt
|
467 |
| - weights_x = np.expand_dims(weights_x, axis=(*np.arange(x_mt.ndim - 3), -1)) |
468 |
| - weights_y = np.expand_dims(weights_y, axis=(*np.arange(y_mt.ndim - 3), -1)) |
| 467 | + weights_x = weights_x[..., np.newaxis] |
| 468 | + weights_y = weights_y[..., np.newaxis] |
469 | 469 | # compute CSD
|
470 | 470 | csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-3)
|
471 | 471 | denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-3)) * np.sqrt(
|
@@ -566,7 +566,10 @@ def _epoch_spectral_connectivity(
|
566 | 566 | if not is_tfr_con: # normal spectra (multitaper or Fourier)
|
567 | 567 | this_psd = _psd_from_mt(x_t, weights)
|
568 | 568 | else: # TFR spectra (multitaper)
|
569 |
| - this_psd = np.array([_tfr_from_mt(epo_x, weights) for epo_x in x_t]) |
| 569 | + # XXX: Move import to top when support for mne<1.10 is dropped |
| 570 | + from mne.time_frequency.tfr import _tfr_from_mt |
| 571 | + |
| 572 | + this_psd = _tfr_from_mt(x_t, weights) |
570 | 573 | else: # mode == 'cwt_morlet'
|
571 | 574 | this_psd = (x_t * x_t.conj()).real
|
572 | 575 | else: # compute spectral info from scratch
|
|
0 commit comments