Skip to content

Commit 8eaa521

Browse files
tsbinnslarsoner
andauthored
Add support for n-dimensional arrays in _tfr_from_mt (mne-tools#13104)
Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent e4cc4e2 commit 8eaa521

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

mne/time_frequency/tfr.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4291,19 +4291,20 @@ def _tfr_from_mt(x_mt, weights):
42914291
42924292
Parameters
42934293
----------
4294-
x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times)
4294+
x_mt : array, shape (..., n_tapers, n_freqs, n_times)
42954295
The complex-valued multitaper coefficients.
42964296
weights : array, shape (n_tapers, n_freqs)
42974297
The weights to use to combine the tapered estimates.
42984298
42994299
Returns
43004300
-------
4301-
tfr : array, shape (n_channels, n_freqs, n_times)
4301+
tfr : array, shape (..., n_freqs, n_times)
43024302
The time-frequency power estimates.
43034303
"""
4304-
weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims
4304+
# add singleton dim for time and any dims preceding the tapers
4305+
weights = weights[..., np.newaxis]
43054306
tfr = weights * x_mt
43064307
tfr *= tfr.conj()
4307-
tfr = tfr.real.sum(axis=1)
4308-
tfr *= 2 / (weights * weights.conj()).real.sum(axis=1)
4308+
tfr = tfr.real.sum(axis=-3)
4309+
tfr *= 2 / (weights * weights.conj()).real.sum(axis=-3)
43094310
return tfr

0 commit comments

Comments
 (0)