diff --git a/doc/_includes/ged.rst b/doc/_includes/ged.rst new file mode 100644 index 00000000000..8f5fc17131c --- /dev/null +++ b/doc/_includes/ged.rst @@ -0,0 +1,107 @@ +:orphan: + +Generalized eigendecomposition in decoding +========================================== + +.. NOTE: part of this file is included in doc/overview/implementation.rst. + Changes here are reflected there. If you want to link to this content, link + to :ref:`ged` to link to that section of the implementation.rst page. + The next line is a target for :start-after: so we can omit the title from + the include: + ged-begin-content + +This section describes the mathematical formulation and application of +Generalized Eigendecomposition (GED), often used in spatial filtering +and source separation algorithms, such as :class:`mne.decoding.CSP`, +:class:`mne.decoding.SPoC`, :class:`mne.decoding.SSD` and +:class:`mne.preprocessing.Xdawn`. + +The core principle of GED is to find a set of channel weights (spatial filter) +that maximizes the ratio of signal power between two data features. +These features are defined by the researcher and are represented by two covariance matrices: +a "signal" matrix :math:`S` and a "reference" matrix :math:`R`. +For example, :math:`S` could be the covariance of data from a task time interval, +and :math:`S` could be the covariance from a baseline time interval. For more details see :footcite:`Cohen2022`. + +Algebraic formulation of GED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A few definitions first: +Let :math:`n \in \mathbb{N}^+` be a number of channels. +Let :math:`\text{Symm}_n(\mathbb{R}) \subset M_n(\mathbb{R})` be a vector space of real symmetric matrices. +Let :math:`S^n_+, S^n_{++} \subset \text{Symm}_n(\mathbb{R})` be sets of real positive semidefinite and positive definite matrices, respectively. +Let :math:`S, R \in S^n_+` be covariance matrices estimated from electrophysiological data :math:`X_S \in M_{n \times t_S}(\mathbb{R})` and :math:`X_R \in M_{n \times t_R}(\mathbb{R})`. + +GED (or simultaneous diagonalization by congruence) of :math:`S` and :math:`R` +is possible when :math:`R` is full rank (and thus :math:`R \in S^n_{++}`): + +.. math:: + + SW = RWD, + +where :math:`W \in M_n(\mathbb{R})` is an invertible matrix of eigenvectors +of :math:`(S, R)` and :math:`D` is a diagonal matrix of eigenvalues :math:`\lambda_i`. + +Each eigenvector :math:`\mathbf{w} \in W` is a spatial filter that solves +an optimization problem of the form: + +.. math:: + + \operatorname{argmax}_{\mathbf{w}} \frac{\mathbf{w}^t S \mathbf{w}}{\mathbf{w}^t R \mathbf{w}} + +That is, using spatial filters :math:`W` on time-series :math:`X \in M_{n \times t}(\mathbb{R})`: + +.. math:: + + \mathbf{A} = W^t X, + +results in "activation" time-series :math:`A` of the estimated "sources", +such that the ratio of their variances, +:math:`\frac{\text{Var}(\mathbf{w}^T X_S)}{\text{Var}(\mathbf{w}^T X_R)} = \frac{\mathbf{w}^T S \mathbf{w}}{\mathbf{w}^T R \mathbf{w}}`, +is sequentially maximized spatial filters :math:`\mathbf{w}_i`, sorted according to :math:`\lambda_i`. + +GED in the principal subspace +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Unfortunately, :math:`R` might not be full rank depending on the data :math:`X_R` (for example due to average reference, removed PCA/ICA components, etc.). +In such cases, GED can be performed on :math:`S` and :math:`R` in the principal subspace :math:`Q = \operatorname{Im}(C_{ref}) \subset \mathbb{R}^n` of some reference +covariance :math:`C_{ref}` (in Common Spatial Pattern (CSP) algorithm, for example, :math:`C_{ref}=\frac{1}{2}(S+R)` and GED is performed on S and R'=S+R). + +More formally: +Let :math:`r \leq n` be a rank of :math:`C \in S^n_+`. +Let :math:`Q=\operatorname{Im}(C_{ref})` be a principal subspace of :math:`C_{ref}`. +Let :math:`P \in M_{n \times r}(\mathbb{R})` be an isometry formed by orthonormal basis of :math:`Q`. +Let :math:`f:S^n_+ \to S^r_+`, :math:`A \mapsto P^t A P` be a "restricting" map, that restricts quadratic form +:math:`q_A:\mathbb{R}^n \to \mathbb{R}` to :math:`q_{A|_Q}:\mathbb{R}^n \to \mathbb{R}` (in practical terms, :math:`q_A` maps +spatial filters to variance of the spatially filtered data :math:`X_A`). + +Then, the GED of :math:`S` and :math:`R` in the principal subspace :math:`Q` of :math:`C_{ref}` is performed as follows: + +1. :math:`S` and :math:`R` are transformed to :math:`S_Q = f(S) = P^t S P` and :math:`R_Q = f(R) = P^t R P`, + such that :math:`S_Q` and :math:`R_Q` are matrix representations of restricted :math:`q_{S|_Q}` and :math:`q_{R|_Q}`. +2. GED is performed on :math:`S_Q` and :math:`R_Q`: :math:`S_Q W_Q = R_Q W_Q D`. +3. Eigenvectors :math:`W_Q` of :math:`(S_Q, R_Q)` are transformed back to :math:`\mathbb{R}^n` + by :math:`W = P W_Q \in \mathbb{R}^{n \times r}` to obtain :math:`r` spatial filters. + +Note that the solution to the original optimization problem is preserved: + +.. math:: + + \frac{\mathbf{w_Q}^t S_Q \mathbf{w_Q}}{\mathbf{w_Q}^t R_Q \mathbf{w_Q}}= \frac{\mathbf{w_Q}^t (P^t S P) \mathbf{w_Q}}{\mathbf{w_Q}^t (P^t R P) + \mathbf{w_Q}} = \frac{\mathbf{w}^t S \mathbf{w}}{\mathbf{w}^t R \mathbf{w}} = \lambda + + +In addition to restriction, :math:`q_S` and :math:`q_R` can be rescaled based on the whitened :math:`C_{ref}`. +In this case the whitening map :math:`f_{wh}:S^n_+ \to S^r_+`, +:math:`A \mapsto P_{wh}^t A P_{wh}` transforms :math:`A` into matrix representation of :math:`q_{A|Q}` rescaled according to :math:`\Lambda^{-1/2}`, +where :math:`\Lambda` is a diagonal matrix of eigenvalues of :math:`C_{ref}` and so :math:`P_{wh} = P \Lambda^{-1/2}`. + +In MNE-Python, the matrix :math:`P` of the restricting map can be obtained using +:: + + _, ref_evecs, mask = mne.cov._smart_eigh(C_ref, ..., proj_subspace=True, ...) + restr_mat = ref_evecs[mask] + +while :math:`P_{wh}` using: +:: + + restr_mat = compute_whitener(C_ref, ..., pca=True, ...) \ No newline at end of file diff --git a/doc/api/decoding.rst b/doc/api/decoding.rst index c844afc470a..788a62b42da 100644 --- a/doc/api/decoding.rst +++ b/doc/api/decoding.rst @@ -29,6 +29,7 @@ Decoding GeneralizingEstimator SPoC SSD + XdawnTransformer Functions that assist with decoding and model fitting: diff --git a/doc/changes/devel/13259.newfeature.rst b/doc/changes/devel/13259.newfeature.rst new file mode 100644 index 00000000000..d510015e26f --- /dev/null +++ b/doc/changes/devel/13259.newfeature.rst @@ -0,0 +1,3 @@ +Implement GEDTransformer superclass that generalizes +:class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, :class:`mne.decoding.XdawnTransformer`, +:class:`mne.decoding.SSD` and fix related bugs and inconsistencies, by `Gennadiy Belonosov`_. \ No newline at end of file diff --git a/doc/documentation/implementation.rst b/doc/documentation/implementation.rst index ebae0201f7a..49fe31bac9c 100644 --- a/doc/documentation/implementation.rst +++ b/doc/documentation/implementation.rst @@ -124,6 +124,14 @@ Morphing and averaging source estimates :start-after: morph-begin-content +.. _ged: + +Generalized eigendecomposition in decoding +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. include:: ../_includes/ged.rst + :start-after: ged-begin-content + References ^^^^^^^^^^ .. footbibliography:: diff --git a/doc/references.bib b/doc/references.bib index f0addb5f3b2..3b7ddf6dc04 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -282,6 +282,18 @@ @article{Cohen2019 year = {2019} } +@article{Cohen2022, +author = {Cohen, Michael X}, +doi = {10.1016/j.neuroimage.2021.118809}, +journal = {NeuroImage}, +pages = {118809}, +title = {A tutorial on generalized eigendecomposition for denoising, contrast enhancement, and dimension reduction in multichannel electrophysiology}, +volume = {247}, +year = {2022}, +issn = {1053-8119}, + +} + @article{CohenHosaka1976, author = {Cohen, David and Hosaka, Hidehiro}, doi = {10.1016/S0022-0736(76)80041-6}, diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 2b6c89b2140..6a1e7d8ab89 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -17,6 +17,7 @@ __all__ = [ "TransformerMixin", "UnsupervisedSpatialFilter", "Vectorizer", + "XdawnTransformer", "compute_ems", "cross_val_multiscore", "get_coef", @@ -43,3 +44,4 @@ from .transformer import ( UnsupervisedSpatialFilter, Vectorizer, ) +from .xdawn import XdawnTransformer diff --git a/mne/decoding/_covs_ged.py b/mne/decoding/_covs_ged.py new file mode 100644 index 00000000000..074b4c15625 --- /dev/null +++ b/mne/decoding/_covs_ged.py @@ -0,0 +1,284 @@ +"""Covariance estimation for GED transformers.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np + +from .._fiff.meas_info import Info, create_info +from .._fiff.pick import _picks_to_idx, pick_info +from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance +from ..defaults import _handle_default +from ..filter import filter_data +from ..rank import compute_rank +from ..utils import _verbose_safe_false, logger + + +def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, info, rank): + """Concatenate epochs before computing the covariance.""" + _, n_channels, _ = x_class.shape + + x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1) + cov = _regularized_covariance( + x_class, + reg=reg, + method_params=cov_method_params, + rank=rank, + info=info, + cov_kind=cov_kind, + log_rank=log_rank, + log_ch_type="data", + ) + + return cov, n_channels # the weight here is just the number of channels + + +def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, info, rank): + """Mean of per-epoch covariances.""" + name = reg if isinstance(reg, str) else "empirical" + name += " with shrinkage" if isinstance(reg, float) else "" + logger.info( + f"Estimating {cov_kind + (' ' if cov_kind else '')}" + f"covariance (average over epochs; {name.upper()})" + ) + cov = sum( + _regularized_covariance( + this_X, + reg=reg, + method_params=cov_method_params, + rank=rank, + info=info, + cov_kind=cov_kind, + log_rank=log_rank and ii == 0, + log_ch_type="data", + verbose=_verbose_safe_false(), + ) + for ii, this_X in enumerate(x_class) + ) + cov /= len(x_class) + weight = len(x_class) + + return cov, weight + + +def _handle_info_rank(X, info, rank): + if info is None: + # use mag instead of eeg to avoid the cov EEG projection warning + info = create_info(X.shape[1], 1000.0, "mag") + if isinstance(rank, dict): + rank = dict(mag=sum(rank.values())) + + return info, rank + + +def _csp_estimate(X, y, reg, cov_method_params, cov_est, info, rank, norm_trace): + _, n_channels, _ = X.shape + classes_ = np.unique(y) + if cov_est == "concat": + cov_estimator = _concat_cov + elif cov_est == "epoch": + cov_estimator = _epoch_cov + + info, rank = _handle_info_rank(X, info, rank) + if not isinstance(rank, dict): + rank = _compute_rank_raw_array( + np.hstack(X), + info, + rank=rank, + scalings=None, + log_ch_type="data", + ) + + covs = [] + sample_weights = [] + for ci, this_class in enumerate(classes_): + cov, weight = cov_estimator( + X[y == this_class], + cov_kind=f"class={this_class}", + log_rank=ci == 0, + reg=reg, + cov_method_params=cov_method_params, + info=info, + rank=rank, + ) + + if norm_trace: + cov /= np.trace(cov) + + covs.append(cov) + sample_weights.append(weight) + + covs = np.stack(covs) + C_ref = covs.mean(0) + + return covs, C_ref, info, rank, dict(sample_weights=np.array(sample_weights)) + + +def _xdawn_estimate( + X, + y, + reg, + cov_method_params, + R=None, + info=None, + rank="full", +): + classes = np.unique(y) + info, rank = _handle_info_rank(X, info, rank) + + # Retrieve or compute whitening covariance + if R is None: + R = _regularized_covariance( + np.hstack(X), reg, cov_method_params, info, rank=rank + ) + elif isinstance(R, Covariance): + R = R.data + + # Get prototype events + evokeds, toeplitzs = list(), list() + for c in classes: + # Prototyped response for each class + evokeds.append(np.mean(X[y == c, :, :], axis=0)) + toeplitzs.append(1.0) + + covs = [] + for evo, toeplitz in zip(evokeds, toeplitzs): + # Estimate covariance matrix of the prototype response + evo = np.dot(evo, toeplitz) + evo_cov = _regularized_covariance(evo, reg, cov_method_params, info, rank=rank) + covs.append(evo_cov) + + covs.append(R) + C_ref = R + if not isinstance(rank, dict): + rank = _compute_rank_raw_array( + np.hstack(X), + info, + rank=rank, + scalings=None, + log_ch_type="data", + ) + return covs, C_ref, info, rank, dict() + + +def _ssd_estimate( + X, + y, + reg, + cov_method_params, + info, + picks, + n_fft, + filt_params_signal, + filt_params_noise, + rank, + sort_by_spectral_ratio, +): + if isinstance(info, Info): + sfreq = info["sfreq"] + elif isinstance(info, float): # special case, mostly for testing + sfreq = info + info = create_info(X.shape[-2], sfreq, ch_types="eeg") + picks_ = _picks_to_idx(info, picks, none="data", exclude="bads") + X_aux = X[..., picks_, :] + X_signal = filter_data(X_aux, sfreq, **filt_params_signal) + X_noise = filter_data(X_aux, sfreq, **filt_params_noise) + X_noise -= X_signal + if X.ndim == 3: + X_signal = np.hstack(X_signal) + X_noise = np.hstack(X_noise) + + # prevent rank change when computing cov with rank='full' + picked_info = pick_info(info, picks_) + S = _regularized_covariance( + X_signal, + reg=reg, + method_params=cov_method_params, + rank="full", + info=picked_info, + ) + R = _regularized_covariance( + X_noise, + reg=reg, + method_params=cov_method_params, + rank="full", + info=picked_info, + ) + covs = [S, R] + C_ref = S + + all_ranks = list() + for cov in covs: + r = list( + compute_rank( + Covariance( + cov, + picked_info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] + all_ranks.append(r) + rank = np.min(all_ranks) + freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) + freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) + n_fft = min( + int(n_fft if n_fft is not None else sfreq), + X.shape[-1], + ) + kwargs = dict( + X=X, + picks=picks_, + sfreq=sfreq, + n_fft=n_fft, + freqs_signal=freqs_signal, + freqs_noise=freqs_noise, + sort_by_spectral_ratio=sort_by_spectral_ratio, + ) + rank = dict(eeg=rank) + info = picked_info + return covs, C_ref, info, rank, kwargs + + +def _spoc_estimate(X, y, reg, cov_method_params, info, rank): + info, rank = _handle_info_rank(X, info, rank) + # Normalize target variable + target = y.astype(np.float64) + target -= target.mean() + target /= target.std() + + n_epochs, n_channels = X.shape[:2] + + # Estimate single trial covariance + covs = np.empty((n_epochs, n_channels, n_channels)) + for ii, epoch in enumerate(X): + covs[ii] = _regularized_covariance( + epoch, + reg=reg, + method_params=cov_method_params, + rank=rank, + log_ch_type="data", + log_rank=ii == 0, + ) + + S = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0) + R = covs.mean(0) + + covs = [S, R] + C_ref = R + if not isinstance(rank, dict): + rank = _compute_rank_raw_array( + np.hstack(X), + info, + rank=rank, + scalings=None, + log_ch_type="data", + ) + return covs, C_ref, info, rank, dict() diff --git a/mne/decoding/_ged.py b/mne/decoding/_ged.py new file mode 100644 index 00000000000..c50b8d101e0 --- /dev/null +++ b/mne/decoding/_ged.py @@ -0,0 +1,131 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np +import scipy.linalg + +from ..cov import Covariance, _smart_eigh, compute_whitener + + +def _handle_restr_mat(C_ref, restr_type, info, rank): + """Get restricting matrix to C_ref rank-dimensional principal subspace. + + Returns matrix of shape (rank, n_chs) used to restrict or + restrict+rescale (whiten) covariances matrices. + """ + if C_ref is None or restr_type is None: + return None + if restr_type == "whitening": + C_ref_cov = Covariance(C_ref, info.ch_names, info["bads"], info["projs"], 0) + restr_mat = compute_whitener( + C_ref_cov, info, rank=rank, pca=True, verbose="error" + )[0] + elif restr_type == "restricting": + restr_mat = _get_restr_mat(C_ref, info, rank) + else: + raise ValueError( + "restr_type should either be callable or one of " + "('whitening', 'restricting')" + ) + return restr_mat + + +def _smart_ged(S, R, restr_mat=None, R_func=None): + """Perform smart generalized eigenvalue decomposition (GED) of S and R. + + If restr_mat is provided S and R will be restricted to the principal subspace + of a reference matrix with rank r (see _handle_restr_mat), then GED is performed + on the restricted S and R and then generalized eigenvectors are transformed back + to the original space. The g-eigenvectors matrix is of shape (n_chs, r). + If callable R_func is provided the GED will be performed on (S, R_func(S,R)) + """ + if restr_mat is None: + evals, evecs = scipy.linalg.eigh(S, R) + return evals, evecs + + S_restr = restr_mat @ S @ restr_mat.T + R_restr = restr_mat @ R @ restr_mat.T + if R_func is not None: + R_restr = R_func([S_restr, R_restr]) + evals, evecs_restr = scipy.linalg.eigh(S_restr, R_restr) + evecs = restr_mat.T @ evecs_restr + + return evals, evecs + + +def _is_cov_symm_pos_semidef( + cov, rtol=1e-10, atol=1e-11, eval_tol=1e-15, check_pos_semidef=True +): + is_symm = scipy.linalg.issymmetric(cov, rtol=rtol, atol=atol) + if not is_symm: + return False + + if check_pos_semidef: + # numerically slightly negative evals are considered 0 + is_pos_semidef = np.all(scipy.linalg.eigvalsh(cov) >= -eval_tol) + return is_pos_semidef + + return True + + +def _is_cov_pos_def(cov, eval_tol=1e-15): + is_symm = _is_cov_symm_pos_semidef(cov, check_pos_semidef=False) + if not is_symm: + return False + # numerically slightly positive evals are considered 0 + is_pos_def = np.all(scipy.linalg.eigvalsh(cov) > eval_tol) + return is_pos_def + + +def _smart_ajd(covs, restr_mat=None, weights=None): + """Perform smart approximate joint diagonalization. + + If restr_mat is provided all the cov matrices will be restricted to the + principal subspace of a reference matrix with rank r (see _handle_restr_mat), + then GED is performed on the restricted S and R and then generalized eigenvectors + are transformed back to the original space. + The matrix of generalized eigenvectors is of shape (n_chs, r). + """ + from .csp import _ajd_pham + + if restr_mat is None: + is_all_pos_def = all([_is_cov_pos_def(cov) for cov in covs]) + if not is_all_pos_def: + raise ValueError( + "If C_ref is not provided by covariance estimator, " + "all the covs should be positive definite" + ) + evecs, D = _ajd_pham(covs) + return evecs + + else: + covs = np.array([restr_mat @ cov @ restr_mat.T for cov in covs], float) + evecs_restr, D = _ajd_pham(covs) + evecs = _normalize_eigenvectors(evecs_restr.T, covs, weights) + evecs = restr_mat.T @ evecs + return evecs + + +def _get_restr_mat(C, info, rank): + """Get matrix restricting covariance to rank-dimensional principal subspace of C.""" + _, ref_evecs, mask = _smart_eigh( + C, + info, + rank, + proj_subspace=True, + do_compute_rank=False, + log_ch_type="data", + ) + restr_mat = ref_evecs[mask] + return restr_mat + + +def _normalize_eigenvectors(evecs, covs, sample_weights): + # Here we apply an euclidean mean. See pyRiemann for other metrics + mean_cov = np.average(covs, axis=0, weights=sample_weights) + + for ii in range(evecs.shape[1]): + tmp = np.dot(np.dot(evecs[:, ii].T, mean_cov), evecs[:, ii]) + evecs[:, ii] /= np.sqrt(tmp) + return evecs diff --git a/mne/decoding/_mod_ged.py b/mne/decoding/_mod_ged.py new file mode 100644 index 00000000000..df917a78ae3 --- /dev/null +++ b/mne/decoding/_mod_ged.py @@ -0,0 +1,132 @@ +"""Eigenvalue eigenvector modifiers for GED transformers.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np + +from ..time_frequency import psd_array_welch +from ..utils import _time_mask + + +def _compute_mutual_info(covs, sample_weights, evecs): + class_probas = sample_weights / sample_weights.sum() + + mutual_info = [] + for jj in range(evecs.shape[1]): + aa, bb = 0, 0 + for cov, prob in zip(covs, class_probas): + tmp = np.dot(np.dot(evecs[:, jj].T, cov), evecs[:, jj]) + aa += prob * np.log(np.sqrt(tmp)) + bb += prob * (tmp**2 - 1) + mi = -(aa + (3.0 / 16) * (bb**2)) + mutual_info.append(mi) + + return mutual_info + + +def _csp_mod(evals, evecs, covs, evecs_order, sample_weights): + n_classes = sample_weights.shape[0] + if evecs_order == "mutual_info" and n_classes > 2: + mutual_info = _compute_mutual_info(covs, sample_weights, evecs) + ix = np.argsort(mutual_info)[::-1] + elif evecs_order == "mutual_info" and n_classes == 2: + ix = np.argsort(np.abs(evals - 0.5))[::-1] + elif evecs_order == "alternate" and n_classes == 2: + i = np.argsort(evals) + ix = np.empty_like(i) + ix[1::2] = i[: len(i) // 2] + ix[0::2] = i[len(i) // 2 :][::-1] + if evals is not None: + evals = evals[ix] + evecs = evecs[:, ix] + sorter = ix + return evals, evecs, sorter + + +def _xdawn_mod(evals, evecs, covs=None): + evals, evecs, sorter = _sort_descending(evals, evecs) + evecs /= np.linalg.norm(evecs, axis=0) + return evals, evecs, sorter + + +def _get_spectral_ratio(ssd_sources, sfreq, n_fft, freqs_signal, freqs_noise): + """Get the spectal signal-to-noise ratio for each spatial filter. + + Spectral ratio measure for best n_components selection + See :footcite:`NikulinEtAl2011`, Eq. (24). + + Returns + ------- + spec_ratio : array, shape (n_channels) + Array with the sprectal ratio value for each component. + sorter_spec : array, shape (n_channels) + Array of indices for sorting spec_ratio. + + References + ---------- + .. footbibliography:: + """ + psd, freqs = psd_array_welch(ssd_sources, sfreq=sfreq, n_fft=n_fft) + sig_idx = _time_mask(freqs, *freqs_signal) + noise_idx = _time_mask(freqs, *freqs_noise) + if psd.ndim == 3: + mean_sig = psd[:, :, sig_idx].mean(axis=2).mean(axis=0) + mean_noise = psd[:, :, noise_idx].mean(axis=2).mean(axis=0) + spec_ratio = mean_sig / mean_noise + else: + mean_sig = psd[:, sig_idx].mean(axis=1) + mean_noise = psd[:, noise_idx].mean(axis=1) + spec_ratio = mean_sig / mean_noise + sorter_spec = spec_ratio.argsort()[::-1] + return spec_ratio, sorter_spec + + +def _ssd_mod( + evals, + evecs, + covs, + X, + picks, + sfreq, + n_fft, + freqs_signal, + freqs_noise, + sort_by_spectral_ratio, +): + evals, evecs, sorter = _sort_descending(evals, evecs) + if sort_by_spectral_ratio: + # We assume that ordering by spectral ratio is more important + # than the initial ordering. + filters = evecs.T + ssd_sources = filters @ X[..., picks, :] + _, sorter_spec = _get_spectral_ratio( + ssd_sources, sfreq, n_fft, freqs_signal, freqs_noise + ) + evecs = evecs[:, sorter_spec] + evals = evals[sorter_spec] + sorter = sorter_spec + return evals, evecs, sorter + + +def _spoc_mod(evals, evecs, covs=None): + evals = evals.real + evecs = evecs.real + evals, evecs, sorter = _sort_descending(evals, evecs, by_abs=True) + return evals, evecs, sorter + + +def _sort_descending(evals, evecs, by_abs=False): + if by_abs: + ix = np.argsort(np.abs(evals))[::-1] + else: + ix = np.argsort(evals)[::-1] + evals = evals[ix] + evecs = evecs[:, ix] + sorter = ix + return evals, evecs, sorter + + +def _no_op_mod(evals, evecs, *args, **kwargs): + return evals, evecs, None diff --git a/mne/decoding/base.py b/mne/decoding/base.py index f73cd976fe3..f207a558bcd 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -6,6 +6,8 @@ import datetime as dt import numbers +from functools import partial +from inspect import Parameter, signature import numpy as np from sklearn import model_selection as models @@ -20,9 +22,258 @@ from sklearn.metrics import check_scoring from sklearn.model_selection import KFold, StratifiedKFold, check_cv from sklearn.utils import check_array, check_X_y, indexable +from sklearn.utils.validation import check_is_fitted from ..parallel import parallel_func -from ..utils import _pl, logger, verbose, warn +from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn +from ._ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged +from ._mod_ged import _no_op_mod +from .transformer import MNETransformerMixin + + +class _GEDTransformer(MNETransformerMixin, BaseEstimator): + """M/EEG signal decomposition using the generalized eigenvalue decomposition (GED). + + Given two channel covariance matrices S and R, the goal is to find spatial filters + that maximise contrast between S and R. + + Parameters + ---------- + n_components : int | None + The number of spatial filters to decompose M/EEG signals. + If None, all of the components will be used for transformation. + Defaults to None. + cov_callable : callable + Function used to estimate covariances and reference matrix (C_ref) from the + data. The only required arguments should be 'X' and optionally 'y'. The function + should return covs, C_ref, info, rank and additional kwargs passed further + to mod_ged_callable. C_ref, info, rank can be None and kwargs can be empty dict. + mod_ged_callable : callable | None + Function used to modify (e.g. sort or normalize) generalized + eigenvalues and eigenvectors. It should accept as arguments evals, evecs + and also covs and optional kwargs returned by cov_callable. It should return + sorted and/or modified evals and evecs and the list of indices according + to which the first two were sorted. If None, evals and evecs will be + ordered according to :func:`~scipy.linalg.eigh` default. Defaults to None. + dec_type : "single" | "multi" + When "single" and cov_callable returns > 2 covariances, + approximate joint diagonalization based on Pham's algorithm + will be used instead of GED. + When 'multi', GED is performed separately for each class, i.e. each covariance + (except the last) returned by cov_callable is decomposed with the last + covariance. In this case, number of covariances should be number of classes + 1. + Defaults to "single". + restr_type : "restricting" | "whitening" | None + Restricting transformation for covariance matrices before performing GED. + If "restricting" only restriction to the principal subspace of the C_ref + will be performed. + If "whitening", covariance matrices will be additionally rescaled according + to the whitening for the C_ref. + If None, no restriction will be applied. Defaults to None. + R_func : callable | None + If provided, GED will be performed on (S, R_func([S,R])). When dec_type is + "single", R_func applicable only if two covariances returned by cov_callable. + If None, GED is performed on (S, R). Defaults to None. + + Attributes + ---------- + evals_ : ndarray, shape (n_channels) + If fit, generalized eigenvalues used to decompose S and R, else None. + filters_ : ndarray, shape (n_channels or less, n_channels) + If fit, spatial filters (unmixing matrix) used to decompose the data, + else None. + patterns_ : ndarray, shape (n_channels or less, n_channels) + If fit, spatial patterns (mixing matrix) used to restore M/EEG signals, + else None. + + See Also + -------- + CSP + SPoC + SSD + + Notes + ----- + .. versionadded:: 1.10 + """ + + def __init__( + self, + cov_callable=None, + n_components=None, + mod_ged_callable=None, + dec_type="single", + restr_type=None, + R_func=None, + ): + self.n_components = n_components + self.cov_callable = cov_callable + self.mod_ged_callable = mod_ged_callable + self.dec_type = dec_type + self.restr_type = restr_type + self.R_func = R_func + + _is_base_ged = True + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls._is_base_ged = False + + def fit(self, X, y=None): + """...""" + # Let the inheriting transformers check data by themselves + if self._is_base_ged: + X, y = self._check_data( + X, + y=y, + fit=True, + return_y=True, + ) + self._validate_ged_params() + covs, C_ref, info, rank, kwargs = self.cov_callable(X, y) + covs = np.stack(covs) + self._validate_covariances(covs) + self._validate_covariances([C_ref]) + mod_ged_callable = ( + self.mod_ged_callable if self.mod_ged_callable is not None else _no_op_mod + ) + + if self.dec_type == "single": + if len(covs) > 2: + weights = ( + kwargs["sample_weights"] if "sample_weights" in kwargs else None + ) + restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) + evecs = _smart_ajd(covs, restr_mat, weights=weights) + evals = None + else: + S = covs[0] + R = covs[1] + restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) + evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func) + + evals, evecs, self.sorter_ = mod_ged_callable(evals, evecs, covs, **kwargs) + self.evals_ = evals + self.filters_ = evecs.T + self.patterns_ = pinv(evecs) + + elif self.dec_type == "multi": + self.classes_ = np.unique(y) + R = covs[-1] + restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) + all_evals, all_evecs = list(), list() + all_patterns, all_sorters = list(), list() + for i in range(len(self.classes_)): + S = covs[i] + + evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func) + + evals, evecs, sorter = mod_ged_callable(evals, evecs, covs, **kwargs) + all_evals.append(evals) + all_evecs.append(evecs.T) + all_patterns.append(pinv(evecs)) + all_sorters.append(sorter) + self.sorter_ = np.array(all_sorters) + self.evals_ = np.array(all_evals) + self.filters_ = np.array(all_evecs) + self.patterns_ = np.array(all_patterns) + + return self + + def transform(self, X): + """...""" + check_is_fitted(self, "filters_") + # Let the inheriting transformers check data by themselves + if self._is_base_ged: + X = self._check_data(X) + if self.dec_type == "single": + pick_filters = self.filters_[: self.n_components] + elif self.dec_type == "multi": + pick_filters = self._subset_multi_components() + X = pick_filters @ X + return X + + def _subset_multi_components(self, name="filters"): + # The shape of stored filters and patterns is + # is (n_classes, n_evecs, n_chs) + # Transform and subset into (n_classes*n_components, n_chs) + if name == "filters": + return self.filters_[:, : self.n_components, :].reshape( + -1, self.filters_.shape[2] + ) + elif name == "patterns": + return self.patterns_[:, : self.n_components, :].reshape( + -1, self.patterns_.shape[2] + ) + return None + + def _validate_required_args(self, func, desired_required_args): + sig = signature(func) + actual_required_args = [ + param.name + for param in sig.parameters.values() + if param.default is Parameter.empty + ] + func_name = func.func.__name__ if isinstance(func, partial) else func.__name__ + if not all(arg in desired_required_args for arg in actual_required_args): + raise ValueError( + f"Invalid required arguments for '{func_name}'. " + f"The only allowed required arguments are {desired_required_args}, " + f"but got {actual_required_args} instead." + ) + + def _validate_ged_params(self): + # Naming is GED-specific so that the validation is still executed + # when child classes run super().fit() + + _validate_type(self.n_components, (int, None), "n_components") + if self.n_components is not None and self.n_components <= 0: + raise ValueError( + "Invalid value for the 'n_components' parameter. " + "Allowed are positive integers or None, " + "but got a non-positive integer instead." + ) + + self._validate_required_args( + self.cov_callable, desired_required_args=["X", "y"] + ) + + _check_option( + "dec_type", + self.dec_type, + ("single", "multi"), + ) + + _check_option( + "restr_type", + self.restr_type, + ("restricting", "whitening", None), + ) + + def _validate_covariances(self, covs): + for cov in covs: + if cov is None: + continue + # XXX: A lot of mne.decoding classes use mne.cov._regularized_covariance. + # Depending on the data it sometimes returns negative semidefinite matrices. + # So adding the validation of positive semidefinitiveness + # will require overhauling covariance estimation first. + is_cov = _is_cov_symm_pos_semidef(cov, check_pos_semidef=False) + if not is_cov: + raise ValueError( + "One of covariances is not symmetric (or positive semidefinite), " + "check your cov_callable" + ) + + def __sklearn_tags__(self): + """Tag the transformer.""" + tags = super().__sklearn_tags__() + # Can be a transformer where S and R covs are not based on y classes. + tags.target_tags.required = False + tags.target_tags.one_d_labels = True + tags.input_tags.two_d_array = True + tags.input_tags.three_d_array = True + return tags class LinearModel(MetaEstimatorMixin, BaseEstimator): diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index ebcc5e4ab61..7855a9ebe60 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -2,30 +2,27 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import collections.abc as abc import copy as cp +from functools import partial import numpy as np -from scipy.linalg import eigh -from sklearn.base import BaseEstimator -from sklearn.utils.validation import check_is_fitted -from .._fiff.meas_info import create_info -from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh +from .._fiff.meas_info import Info from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ..evoked import EvokedArray from ..utils import ( _check_option, _validate_type, - _verbose_safe_false, fill_doc, - logger, - pinv, ) -from .transformer import MNETransformerMixin +from ._covs_ged import _csp_estimate, _spoc_estimate +from ._mod_ged import _csp_mod, _spoc_mod +from .base import _GEDTransformer @fill_doc -class CSP(MNETransformerMixin, BaseEstimator): +class CSP(_GEDTransformer): """M/EEG signal decomposition using the Common Spatial Patterns (CSP). This class can be used as a supervised decomposition to estimate spatial @@ -68,6 +65,26 @@ class CSP(MNETransformerMixin, BaseEstimator): Parameters to pass to :func:`mne.compute_covariance`. .. versionadded:: 0.16 + + restr_type : "restricting" | "whitening" | None + Restricting transformation for covariance matrices before performing + generalized eigendecomposition. + If "restricting" only restriction to the principal subspace of signal_cov + will be performed. + If "whitening", covariance matrices will be additionally rescaled according + to the whitening for the signal_cov. + If None, no restriction will be applied. Defaults to "restricting". + + .. versionadded:: 1.10 + info : mne.Info | None + The mne.Info object with information about the sensors and methods of + measurement used for covariance estimation and generalized + eigendecomposition. + If None, one channel type and no projections will be assumed and if + rank is dict, it will be sum of ranks per channel type. + Defaults to None. + + .. versionadded:: 1.10 %(rank_none)s .. versionadded:: 0.17 @@ -95,7 +112,7 @@ class CSP(MNETransformerMixin, BaseEstimator): See Also -------- - mne.preprocessing.Xdawn, SPoC + XdawnTransformer, SPoC, SSD References ---------- @@ -111,11 +128,14 @@ def __init__( transform_into="average_power", norm_trace=False, cov_method_params=None, + restr_type="restricting", + info=None, rank=None, component_order="mutual_info", ): # Init default CSP self.n_components = n_components + self.info = info self.rank = rank self.reg = reg self.cov_est = cov_est @@ -124,6 +144,25 @@ def __init__( self.norm_trace = norm_trace self.cov_method_params = cov_method_params self.component_order = component_order + self.restr_type = restr_type + + cov_callable = partial( + _csp_estimate, + reg=reg, + cov_method_params=cov_method_params, + cov_est=cov_est, + info=info, + rank=rank, + norm_trace=norm_trace, + ) + mod_ged_callable = partial(_csp_mod, evecs_order=component_order) + super().__init__( + n_components=n_components, + cov_callable=cov_callable, + mod_ged_callable=mod_ged_callable, + restr_type=restr_type, + R_func=sum, + ) def _validate_params(self, *, y): _validate_type(self.n_components, int, "n_components") @@ -153,6 +192,14 @@ def _validate_params(self, *, y): n_classes = len(self.classes_) if n_classes < 2: raise ValueError(f"n_classes must be >= 2, but got {n_classes} class") + elif n_classes > 2 and self.component_order == "alternate": + raise ValueError( + "component_order='alternate' requires two classes, but data contains " + f"{n_classes} classes; use component_order='mutual_info' instead." + ) + _validate_type(self.rank, (dict, None, str), "rank") + _validate_type(self.info, (Info, None), "info") + _validate_type(self.cov_method_params, (abc.Mapping, None), "cov_method_params") def fit(self, X, y): """Estimate the CSP decomposition on epochs. @@ -171,26 +218,10 @@ def fit(self, X, y): """ X, y = self._check_data(X, y=y, fit=True, return_y=True) self._validate_params(y=y) - n_classes = len(self.classes_) - if n_classes > 2 and self.component_order == "alternate": - raise ValueError( - "component_order='alternate' requires two classes, but data contains " - f"{n_classes} classes; use component_order='mutual_info' instead." - ) - - # Convert rank to one that will run - _validate_type(self.rank, (dict, None, str), "rank") - - covs, sample_weights = self._compute_covariance_matrices(X, y) - eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights) - ix = self._order_components( - covs, sample_weights, eigen_vectors, eigen_values, self.component_order - ) - - eigen_vectors = eigen_vectors[:, ix] - self.filters_ = eigen_vectors.T - self.patterns_ = pinv(eigen_vectors) + # Covariance estimation, GED/AJD + # and evecs/evals sorting happen here + super().fit(X, y) pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) @@ -220,11 +251,8 @@ def transform(self, X): If self.transform_into == 'csp_space' then returns the data in CSP space and shape is (n_epochs, n_components, n_times). """ - check_is_fitted(self, "filters_") X = self._check_data(X) - pick_filters = self.filters_[: self.n_components] - X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) - + X = super().transform(X) # compute features (mean band power) if self.transform_into == "average_power": X = (X**2).mean(axis=2) @@ -543,162 +571,6 @@ def plot_filters( ) return fig - def _compute_covariance_matrices(self, X, y): - _, n_channels, _ = X.shape - - if self.cov_est == "concat": - cov_estimator = self._concat_cov - elif self.cov_est == "epoch": - cov_estimator = self._epoch_cov - - # Someday we could allow the user to pass this, then we wouldn't need to convert - # but in the meantime they can use a pipeline with a scaler - self._info = create_info(n_channels, 1000.0, "mag") - if isinstance(self.rank, dict): - self._rank = {"mag": sum(self.rank.values())} - else: - self._rank = _compute_rank_raw_array( - X.transpose(1, 0, 2).reshape(X.shape[1], -1), - self._info, - rank=self.rank, - scalings=None, - log_ch_type="data", - ) - - covs = [] - sample_weights = [] - for ci, this_class in enumerate(self.classes_): - cov, weight = cov_estimator( - X[y == this_class], - cov_kind=f"class={this_class}", - log_rank=ci == 0, - ) - - if self.norm_trace: - cov /= np.trace(cov) - - covs.append(cov) - sample_weights.append(weight) - - return np.stack(covs), np.array(sample_weights) - - def _concat_cov(self, x_class, *, cov_kind, log_rank): - """Concatenate epochs before computing the covariance.""" - _, n_channels, _ = x_class.shape - - x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1) - cov = _regularized_covariance( - x_class, - reg=self.reg, - method_params=self.cov_method_params, - rank=self._rank, - info=self._info, - cov_kind=cov_kind, - log_rank=log_rank, - log_ch_type="data", - ) - weight = x_class.shape[0] - - return cov, weight - - def _epoch_cov(self, x_class, *, cov_kind, log_rank): - """Mean of per-epoch covariances.""" - name = self.reg if isinstance(self.reg, str) else "empirical" - name += " with shrinkage" if isinstance(self.reg, float) else "" - logger.info( - f"Estimating {cov_kind + (' ' if cov_kind else '')}" - f"covariance (average over epochs; {name.upper()})" - ) - cov = sum( - _regularized_covariance( - this_X, - reg=self.reg, - method_params=self.cov_method_params, - rank=self._rank, - info=self._info, - cov_kind=cov_kind, - log_rank=log_rank and ii == 0, - log_ch_type="data", - verbose=_verbose_safe_false(), - ) - for ii, this_X in enumerate(x_class) - ) - cov /= len(x_class) - weight = len(x_class) - - return cov, weight - - def _decompose_covs(self, covs, sample_weights): - n_classes = len(covs) - n_channels = covs[0].shape[0] - assert self._rank is not None # should happen in _compute_covariance_matrices - _, sub_vec, mask = _smart_eigh( - covs.mean(0), - self._info, - self._rank, - proj_subspace=True, - do_compute_rank=False, - log_ch_type="data", - verbose=_verbose_safe_false(), - ) - sub_vec = sub_vec[mask] - covs = np.array([sub_vec @ cov @ sub_vec.T for cov in covs], float) - assert covs[0].shape == (mask.sum(),) * 2 - if n_classes == 2: - eigen_values, eigen_vectors = eigh(covs[0], covs.sum(0)) - else: - # The multiclass case is adapted from - # http://github.com/alexandrebarachant/pyRiemann - eigen_vectors, D = _ajd_pham(covs) - eigen_vectors = self._normalize_eigenvectors( - eigen_vectors.T, covs, sample_weights - ) - eigen_values = None - # project back - eigen_vectors = sub_vec.T @ eigen_vectors - assert eigen_vectors.shape == (n_channels, mask.sum()) - return eigen_vectors, eigen_values - - def _compute_mutual_info(self, covs, sample_weights, eigen_vectors): - class_probas = sample_weights / sample_weights.sum() - - mutual_info = [] - for jj in range(eigen_vectors.shape[1]): - aa, bb = 0, 0 - for cov, prob in zip(covs, class_probas): - tmp = np.dot(np.dot(eigen_vectors[:, jj].T, cov), eigen_vectors[:, jj]) - aa += prob * np.log(np.sqrt(tmp)) - bb += prob * (tmp**2 - 1) - mi = -(aa + (3.0 / 16) * (bb**2)) - mutual_info.append(mi) - - return mutual_info - - def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights): - # Here we apply an euclidean mean. See pyRiemann for other metrics - mean_cov = np.average(covs, axis=0, weights=sample_weights) - - for ii in range(eigen_vectors.shape[1]): - tmp = np.dot(np.dot(eigen_vectors[:, ii].T, mean_cov), eigen_vectors[:, ii]) - eigen_vectors[:, ii] /= np.sqrt(tmp) - return eigen_vectors - - def _order_components( - self, covs, sample_weights, eigen_vectors, eigen_values, component_order - ): - n_classes = len(self.classes_) - if component_order == "mutual_info" and n_classes > 2: - mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors) - ix = np.argsort(mutual_info)[::-1] - elif component_order == "mutual_info" and n_classes == 2: - ix = np.argsort(np.abs(eigen_values - 0.5))[::-1] - elif component_order == "alternate" and n_classes == 2: - i = np.argsort(eigen_values) - ix = np.empty_like(i) - ix[1::2] = i[: len(i) // 2] - ix[0::2] = i[len(i) // 2 :][::-1] - return ix - def _ajd_pham(X, eps=1e-6, max_iter=15): """Approximate joint diagonalization based on Pham's algorithm. @@ -821,6 +693,25 @@ class SPoC(CSP): Parameters to pass to :func:`mne.compute_covariance`. .. versionadded:: 0.16 + restr_type : "restricting" | "whitening" | None + Restricting transformation for covariance matrices before performing + generalized eigendecomposition. + If "restricting" only restriction to the principal subspace of signal_cov + will be performed. + If "whitening", covariance matrices will be additionally rescaled according + to the whitening for the signal_cov. + If None, no restriction will be applied. Defaults to None. + + .. versionadded:: 1.10 + info : mne.Info | None + The mne.Info object with information about the sensors and methods of + measurement used for covariance estimation and generalized + eigendecomposition. + If None, one channel type and no projections will be assumed and if + rank is dict, it will be sum of ranks per channel type. + Defaults to None. + + .. versionadded:: 1.10 %(rank_none)s .. versionadded:: 0.17 @@ -852,6 +743,8 @@ def __init__( log=None, transform_into="average_power", cov_method_params=None, + restr_type=None, + info=None, rank=None, ): """Init of SPoC.""" @@ -862,9 +755,26 @@ def __init__( cov_est="epoch", norm_trace=False, transform_into=transform_into, + restr_type=restr_type, + info=info, rank=rank, cov_method_params=cov_method_params, ) + + cov_callable = partial( + _spoc_estimate, + reg=reg, + cov_method_params=cov_method_params, + info=info, + rank=rank, + ) + super(CSP, self).__init__( + n_components=n_components, + cov_callable=cov_callable, + mod_ged_callable=_spoc_mod, + restr_type=restr_type, + ) + # Covariance estimation have to be done on the single epoch level, # unlike CSP where covariance estimation can also be achieved through # concatenation of all epochs from the same class. @@ -889,43 +799,7 @@ def fit(self, X, y): X, y = self._check_data(X, y=y, fit=True, return_y=True) self._validate_params(y=y) - # The following code is directly copied from pyRiemann - - # Normalize target variable - target = y.astype(np.float64) - target -= target.mean() - target /= target.std() - - n_epochs, n_channels = X.shape[:2] - - # Estimate single trial covariance - covs = np.empty((n_epochs, n_channels, n_channels)) - for ii, epoch in enumerate(X): - covs[ii] = _regularized_covariance( - epoch, - reg=self.reg, - method_params=self.cov_method_params, - rank=self.rank, - log_ch_type="data", - log_rank=ii == 0, - ) - - C = covs.mean(0) - Cz = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0) - - # solve eigenvalue decomposition - evals, evecs = eigh(Cz, C) - evals = evals.real - evecs = evecs.real - # sort vectors - ix = np.argsort(np.abs(evals))[::-1] - - # sort eigenvectors - evecs = evecs[:, ix].T - - # spatial patterns - self.patterns_ = pinv(evecs).T # n_channels x n_channels - self.filters_ = evecs # n_channels x n_channels + super(CSP, self).fit(X, y) pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 111ded9f274..23cc34a9865 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -2,30 +2,26 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import collections.abc as abc +from functools import partial + import numpy as np -from scipy.linalg import eigh -from sklearn.base import BaseEstimator -from sklearn.utils.validation import check_is_fitted from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx -from ..cov import Covariance, _regularized_covariance -from ..defaults import _handle_default from ..filter import filter_data -from ..rank import compute_rank -from ..time_frequency import psd_array_welch from ..utils import ( - _time_mask, _validate_type, - _verbose_safe_false, fill_doc, logger, ) -from .transformer import MNETransformerMixin +from ._covs_ged import _ssd_estimate +from ._mod_ged import _ssd_mod +from .base import _GEDTransformer @fill_doc -class SSD(MNETransformerMixin, BaseEstimator): +class SSD(_GEDTransformer): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). @@ -71,6 +67,17 @@ class SSD(MNETransformerMixin, BaseEstimator): cov_method_params : dict | None (default None) As in :class:`mne.decoding.SPoC` The default is None. + restr_type : "restricting" | "whitening" | "ssd" | None + Restricting transformation for covariance matrices before performing + generalized eigendecomposition. + If "restricting" only restriction to the principal subspace of signal_cov + will be performed. + If "whitening", covariance matrices will be additionally rescaled according + to the whitening for the signal_cov. + If "ssd", simplified version of "whitening" is performed. + If None, no restriction will be applied. Defaults to "ssd". + + .. versionadded:: 1.10 rank : None | dict | ‘info’ | ‘full’ As in :class:`mne.decoding.SPoC` This controls the rank computation that can be read from the @@ -81,9 +88,9 @@ class SSD(MNETransformerMixin, BaseEstimator): Attributes ---------- - filters_ : array, shape (n_channels, n_components) + filters_ : array, shape (``n_channels or less``, n_channels) The spatial filters to be multiplied with the signal. - patterns_ : array, shape (n_components, n_channels) + patterns_ : array, shape (``n_channels or less``, n_channels) The patterns for reconstructing the signal from the filtered data. References @@ -103,6 +110,7 @@ def __init__( return_filtered=False, n_fft=None, cov_method_params=None, + restr_type="whitening", rank=None, ): """Initialize instance.""" @@ -116,8 +124,28 @@ def __init__( self.return_filtered = return_filtered self.n_fft = n_fft self.cov_method_params = cov_method_params + self.restr_type = restr_type self.rank = rank + cov_callable = partial( + _ssd_estimate, + reg=reg, + cov_method_params=cov_method_params, + info=info, + picks=picks, + n_fft=n_fft, + filt_params_signal=filt_params_signal, + filt_params_noise=filt_params_noise, + rank=rank, + sort_by_spectral_ratio=sort_by_spectral_ratio, + ) + super().__init__( + n_components=n_components, + cov_callable=cov_callable, + mod_ged_callable=_ssd_mod, + restr_type=restr_type, + ) + def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing self.sfreq_ = self.info @@ -166,6 +194,7 @@ def _validate_params(self, X): "At this point SSD only supports fitting " f"single channel types. Your info has {len(ch_types)} types." ) + _validate_type(self.cov_method_params, (abc.Mapping, None), "cov_method_params") def _check_X(self, X, *, y=None, fit=False): """Check input data.""" @@ -202,52 +231,9 @@ def fit(self, X, y=None): else: info = create_info(X.shape[-2], self.sfreq_, ch_types="eeg") self.picks_ = _picks_to_idx(info, self.picks, none="data", exclude="bads") - X_aux = X[..., self.picks_, :] - X_signal = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) - X_noise = filter_data(X_aux, self.sfreq_, **self.filt_params_noise) - X_noise -= X_signal - if X.ndim == 3: - X_signal = np.hstack(X_signal) - X_noise = np.hstack(X_noise) - - # prevent rank change when computing cov with rank='full' - cov_signal = _regularized_covariance( - X_signal, - reg=self.reg, - method_params=self.cov_method_params, - rank="full", - info=info, - ) - cov_noise = _regularized_covariance( - X_noise, - reg=self.reg, - method_params=self.cov_method_params, - rank="full", - info=info, - ) - - # project cov to rank subspace - cov_signal, cov_noise, rank_proj = _dimensionality_reduction( - cov_signal, cov_noise, info, self.rank - ) + super().fit(X, y) - eigvals_, eigvects_ = eigh(cov_signal, cov_noise) - # sort in descending order - ix = np.argsort(eigvals_)[::-1] - self.eigvals_ = eigvals_[ix] - # project back to sensor space - self.filters_ = np.matmul(rank_proj, eigvects_[:, ix]) - self.patterns_ = np.linalg.pinv(self.filters_) - - # We assume that ordering by spectral ratio is more important - # than the initial ordering. This ordering should be also learned when - # fitting. - X_ssd = self.filters_.T @ X[..., self.picks_, :] - sorter_spec = slice(None) - if self.sort_by_spectral_ratio: - _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd) - self.sorter_spec_ = sorter_spec logger.info("Done.") return self @@ -266,13 +252,15 @@ def transform(self, X): X_ssd : array, shape ([n_epochs, ]n_components, n_times) The processed data. """ - check_is_fitted(self, "filters_") X = self._check_X(X) + # For the case where n_epochs dimension is absent. + if X.ndim == 2: + X = np.expand_dims(X, axis=0) + X_aux = X[..., self.picks_, :] if self.return_filtered: - X_aux = X[..., self.picks_, :] - X = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) - X_ssd = self.filters_.T @ X[..., self.picks_, :] - X_ssd = X_ssd[..., self.sorter_spec_, :][..., : self.n_components, :] + X_aux = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) + X_ssd = super().transform(X_aux).squeeze() + return X_ssd def fit_transform(self, X, y=None, **fit_params): @@ -300,42 +288,6 @@ def fit_transform(self, X, y=None, **fit_params): # use parent TransformerMixin method but with custom docstring return super().fit_transform(X, y=y, **fit_params) - def get_spectral_ratio(self, ssd_sources): - """Get the spectal signal-to-noise ratio for each spatial filter. - - Spectral ratio measure for best n_components selection - See :footcite:`NikulinEtAl2011`, Eq. (24). - - Parameters - ---------- - ssd_sources : array - Data projected to SSD space. - - Returns - ------- - spec_ratio : array, shape (n_channels) - Array with the sprectal ratio value for each component. - sorter_spec : array, shape (n_channels) - Array of indices for sorting spec_ratio. - - References - ---------- - .. footbibliography:: - """ - psd, freqs = psd_array_welch(ssd_sources, sfreq=self.sfreq_, n_fft=self.n_fft_) - sig_idx = _time_mask(freqs, *self.freqs_signal_) - noise_idx = _time_mask(freqs, *self.freqs_noise_) - if psd.ndim == 3: - mean_sig = psd[:, :, sig_idx].mean(axis=2).mean(axis=0) - mean_noise = psd[:, :, noise_idx].mean(axis=2).mean(axis=0) - spec_ratio = mean_sig / mean_noise - else: - mean_sig = psd[:, sig_idx].mean(axis=1) - mean_noise = psd[:, noise_idx].mean(axis=1) - spec_ratio = mean_sig / mean_noise - sorter_spec = spec_ratio.argsort()[::-1] - return spec_ratio, sorter_spec - def inverse_transform(self): """Not implemented yet.""" raise NotImplementedError("inverse_transform is not yet available.") @@ -364,68 +316,6 @@ def apply(self, X): The processed data. """ X_ssd = self.transform(X) - pick_patterns = self.patterns_[self.sorter_spec_][: self.n_components].T + pick_patterns = self.patterns_[: self.n_components].T X = pick_patterns @ X_ssd return X - - -def _dimensionality_reduction(cov_signal, cov_noise, info, rank): - """Perform dimensionality reduction on the covariance matrices.""" - n_channels = cov_signal.shape[0] - - # find ranks of covariance matrices - rank_signal = list( - compute_rank( - Covariance( - cov_signal, - info.ch_names, - list(), - list(), - 0, - verbose=_verbose_safe_false(), - ), - rank, - _handle_default("scalings_cov_rank", None), - info, - ).values() - )[0] - rank_noise = list( - compute_rank( - Covariance( - cov_noise, - info.ch_names, - list(), - list(), - 0, - verbose=_verbose_safe_false(), - ), - rank, - _handle_default("scalings_cov_rank", None), - info, - ).values() - )[0] - rank = np.min([rank_signal, rank_noise]) # should be identical - - if rank < n_channels: - eigvals, eigvects = eigh(cov_signal) - # sort in descending order - ix = np.argsort(eigvals)[::-1] - eigvals = eigvals[ix] - eigvects = eigvects[:, ix] - # compute rank subspace projection matrix - rank_proj = np.matmul( - eigvects[:, :rank], np.eye(rank) * (eigvals[:rank] ** -0.5) - ) - logger.info( - "Projecting covariance of %i channels to %i rank subspace", - n_channels, - rank, - ) - else: - rank_proj = np.eye(n_channels) - logger.info("Preserving covariance rank (%i)", rank) - - # project covariance matrices to rank subspace - cov_signal = np.matmul(rank_proj.T, np.matmul(cov_signal, rank_proj)) - cov_noise = np.matmul(rank_proj.T, np.matmul(cov_noise, rank_proj)) - return cov_signal, cov_noise, rank_proj diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py new file mode 100644 index 00000000000..5f413f81aee --- /dev/null +++ b/mne/decoding/tests/test_ged.py @@ -0,0 +1,388 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from functools import partial +from pathlib import Path + +import numpy as np +import pytest + +pytest.importorskip("sklearn") + + +from sklearn.model_selection import ParameterGrid +from sklearn.utils._testing import assert_allclose +from sklearn.utils.estimator_checks import parametrize_with_checks + +from mne import Epochs, compute_rank, create_info, pick_types, read_events +from mne._fiff.proj import make_eeg_average_ref_proj +from mne.cov import Covariance, _regularized_covariance +from mne.decoding._ged import ( + _get_restr_mat, + _handle_restr_mat, + _is_cov_pos_def, + _is_cov_symm_pos_semidef, + _smart_ajd, + _smart_ged, +) +from mne.decoding._mod_ged import _no_op_mod +from mne.decoding.base import _GEDTransformer +from mne.io import read_raw + +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_dir / "test_raw.fif" +event_name = data_dir / "test-eve.fif" +tmin, tmax = -0.1, 0.2 +# if stop is too small pca may fail in some cases, but we're okay on this file +start, stop = 0, 8 + + +def _mock_info(n_channels): + info = create_info(n_channels, 1000.0, "eeg") + avg_eeg_projector = make_eeg_average_ref_proj(info=info, activate=False) + info["projs"].append(avg_eeg_projector) + return info + + +def _get_min_rank(covs, info): + min_rank = dict( + eeg=min( + list( + compute_rank( + Covariance( + cov, + info.ch_names, + list(), + list(), + 0, + # verbose=_verbose_safe_false(), + ), + rank=None, + # _handle_default("scalings_cov_rank", None), + info=info, + ).values() + )[0] + for cov in covs + ) + ) + return min_rank + + +def _mock_cov_callable(X, y, cov_method_params=None, compute_C_ref=True): + if cov_method_params is None: + cov_method_params = dict() + n_epochs, n_channels, n_times = X.shape + + # To pass sklearn check: + if n_channels == 1: + n_channels = 2 + X = np.tile(X, (1, n_channels, 1)) + + # To make covariance estimation sensible + if n_times == 1: + n_times = n_channels + X = np.tile(X, (1, 1, n_channels)) + + classes = np.unique(y) + covs, sample_weights = list(), list() + for ci, this_class in enumerate(classes): + class_data = X[y == this_class] + class_data = class_data.transpose(1, 0, 2).reshape(n_channels, -1) + cov = _regularized_covariance(class_data, **cov_method_params) + covs.append(cov) + sample_weights.append(class_data.shape[0]) + + ref_data = X.transpose(1, 0, 2).reshape(n_channels, -1) + if compute_C_ref: + C_ref = _regularized_covariance(ref_data, **cov_method_params) + else: + C_ref = None + info = _mock_info(n_channels) + rank = _get_min_rank(covs, info) + kwargs = dict() + + # To pass sklearn check: + if len(covs) == 1: + covs.append(covs[0]) + + elif len(covs) > 2: + kwargs["sample_weights"] = sample_weights + return covs, C_ref, info, rank, kwargs + + +def _mock_mod_ged_callable(evals, evecs, covs, **kwargs): + sorter = None + if evals is not None: + ix = np.argsort(evals)[::-1] + evals = evals[ix] + evecs = evecs[:, ix] + sorter = ix + return evals, evecs, sorter + + +param_grid = dict( + n_components=[4], + cov_callable=[partial(_mock_cov_callable, cov_method_params=dict(reg="empirical"))], + mod_ged_callable=[_mock_mod_ged_callable], + dec_type=["single", "multi"], + # XXX: Not covering "ssd" here because test_ssd.py works with 2D data. + # Need to fix its tests first. + restr_type=[ + "restricting", + "whitening", + ], + R_func=[partial(np.sum, axis=0)], +) + +ged_estimators = [_GEDTransformer(**p) for p in ParameterGrid(param_grid)] + + +@pytest.mark.slowtest +@parametrize_with_checks(ged_estimators) +def test_sklearn_compliance(estimator, check): + """Test GEDTransformer compliance with sklearn.""" + check(estimator) + + +def _get_X_y(event_id): + raw = read_raw(raw_fname, preload=False) + events = read_events(event_name) + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) + picks = picks[2:12:3] # subselect channels -> disable proj! + raw.add_proj([], remove_existing=True) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + proj=False, + ) + X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT")) + y = epochs.events[:, -1] + return X, y + + +def test_ged_binary_cov(): + """Test GEDTransformer on audvis dataset with two covariances.""" + event_id = dict(aud_l=1, vis_l=3) + X, y = _get_X_y(event_id) + # Test "single" decomposition + covs, C_ref, info, rank, kwargs = _mock_cov_callable(X, y) + S, R = covs[0], covs[1] + restr_mat = _get_restr_mat(C_ref, info, rank) + evals, evecs = _smart_ged(S, R, restr_mat=restr_mat, R_func=None) + actual_evals, actual_evecs, sorter = _mock_mod_ged_callable( + evals, evecs, [S, R], **kwargs + ) + actual_filters = actual_evecs.T + + ged = _GEDTransformer( + n_components=4, + cov_callable=_mock_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + restr_type="restricting", + ) + ged.fit(X, y) + desired_evals = ged.evals_ + desired_filters = ged.filters_ + + assert_allclose(actual_evals, desired_evals) + assert_allclose(actual_filters, desired_filters) + + # Test "multi" decomposition (loop), restr_mat can be reused + all_evals, all_evecs = list(), list() + for i in range(len(covs)): + S = covs[i] + evals, evecs = _smart_ged(S, R, restr_mat) + evals, evecs, sorter = _mock_mod_ged_callable(evals, evecs, covs) + all_evals.append(evals) + all_evecs.append(evecs.T) + actual_evals = np.array(all_evals) + actual_filters = np.array(all_evecs) + + ged = _GEDTransformer( + n_components=4, + cov_callable=_mock_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + dec_type="multi", + restr_type="restricting", + ) + ged.fit(X, y) + desired_evals = ged.evals_ + desired_filters = ged.filters_ + + assert_allclose(actual_evals, desired_evals) + assert_allclose(actual_filters, desired_filters) + + assert ged._subset_multi_components(name="foo") is None + + +def test_ged_multicov(): + """Test GEDTransformer on audvis dataset with multiple covariances.""" + event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4) + X, y = _get_X_y(event_id) + # Test "single" decomposition for multicov (AJD) with C_ref + covs, C_ref, info, rank, kwargs = _mock_cov_callable(X, y) + restr_mat = _get_restr_mat(C_ref, info, rank) + evecs = _smart_ajd(covs, restr_mat=restr_mat) + evals = None + _, actual_evecs, _ = _mock_mod_ged_callable(evals, evecs, covs, **kwargs) + actual_filters = actual_evecs.T + + ged = _GEDTransformer( + n_components=4, + cov_callable=_mock_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + restr_type="restricting", + ) + ged.fit(X, y) + desired_filters = ged.filters_ + + assert_allclose(actual_filters, desired_filters) + + # Test "multi" decomposition for multicov (loop) + R = covs[-1] + all_evals, all_evecs = list(), list() + for i in range(len(covs)): + S = covs[i] + evals, evecs = _smart_ged(S, R, restr_mat) + evals, evecs, sorter = _mock_mod_ged_callable(evals, evecs, covs) + all_evals.append(evals) + all_evecs.append(evecs.T) + actual_evals = np.array(all_evals) + actual_filters = np.array(all_evecs) + + ged = _GEDTransformer( + n_components=4, + cov_callable=_mock_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + dec_type="multi", + restr_type="restricting", + ) + ged.fit(X, y) + desired_evals = ged.evals_ + desired_filters = ged.filters_ + + assert_allclose(actual_evals, desired_evals) + assert_allclose(actual_filters, desired_filters) + + # Test "single" decomposition for multicov (AJD) without C_ref + covs, C_ref, info, rank, kwargs = _mock_cov_callable( + X, y, cov_method_params=dict(reg="oas"), compute_C_ref=False + ) + covs = np.stack(covs) + evecs = _smart_ajd(covs, restr_mat=None) + evals = None + _, actual_evecs, _ = _mock_mod_ged_callable(evals, evecs, covs, **kwargs) + actual_filters = actual_evecs.T + + ged = _GEDTransformer( + n_components=4, + cov_callable=partial( + _mock_cov_callable, cov_method_params=dict(reg="oas"), compute_C_ref=False + ), + mod_ged_callable=_mock_mod_ged_callable, + restr_type="restricting", + ) + ged.fit(X, y) + desired_filters = ged.filters_ + + assert_allclose(actual_filters, desired_filters) + + +def test_ged_validation_raises(): + """Test GEDTransofmer validation raises correct errors.""" + event_id = dict(aud_l=1, vis_l=3) + X, y = _get_X_y(event_id) + + ged = _GEDTransformer( + n_components=-1, + cov_callable=_mock_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + restr_type="restricting", + ) + with pytest.raises(ValueError): + ged.fit(X, y) + + def _bad_cov_callable(X, y, foo): + return X, y, foo + + ged = _GEDTransformer( + n_components=1, + cov_callable=_bad_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + restr_type="restricting", + ) + with pytest.raises(ValueError): + ged.fit(X, y) + + +def test_ged_invalid_cov(): + """Test _validate_covariances raises proper errors.""" + ged = _GEDTransformer( + n_components=1, + cov_callable=_mock_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + ) + asymm_cov = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + with pytest.raises(ValueError, match="not symmetric"): + ged._validate_covariances([asymm_cov, None]) + + +def test__handle_restr_mat_invalid_restr_type(): + """Test _handle_restr_mat raises correct error when wrong restr_type.""" + C_ref = np.eye(3) + with pytest.raises(ValueError, match="restr_type"): + _handle_restr_mat(C_ref, restr_type="blah", info=None, rank=None) + + +def test_cov_validators(): + """Test that covariance validators indeed validate.""" + asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) + + assert not _is_cov_symm_pos_semidef(asymm) + assert _is_cov_symm_pos_semidef(sing_pos_semidef) + assert _is_cov_symm_pos_semidef(pos_def) + + assert not _is_cov_pos_def(asymm) + assert not _is_cov_pos_def(sing_pos_semidef) + assert _is_cov_pos_def(pos_def) + + +def test__is_cov_pos_def(): + """Test _is_cov_pos_def works.""" + asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) + assert not _is_cov_pos_def(asymm) + assert not _is_cov_pos_def(sing_pos_semidef) + assert _is_cov_pos_def(pos_def) + + +def test__smart_ajd_when_restr_mat_is_none(): + """Test _smart_ajd raises ValueError when restr_mat is None.""" + sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + pos_def1 = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) + pos_def2 = np.array([[10, 1, 2], [1, 12, 3], [2, 3, 15]]) + bad_covs = np.stack([sing_pos_semidef, pos_def1, pos_def2]) + with pytest.raises(ValueError, match="positive definite"): + _smart_ajd(bad_covs, restr_mat=None, weights=None) + + +def test__no_op_mod(): + """Test _no_op_mod returns the same evals/evecs objects.""" + evals = np.array([[1, 2], [3, 4]]) + evecs = np.array([0, 1]) + evals_no_op, evecs_no_op, sorter_no_op = _no_op_mod(evals, evecs) + assert evals is evals_no_op + assert evecs is evecs_no_op + assert sorter_no_op is None diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 20183776709..69ca232b111 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -3,6 +3,7 @@ # Copyright the MNE-Python contributors. import sys +from pathlib import Path import numpy as np import pytest @@ -13,8 +14,10 @@ from sklearn.pipeline import Pipeline from sklearn.utils.estimator_checks import parametrize_with_checks -from mne import create_info, io +from mne import Epochs, create_info, io, pick_types, read_events +from mne._fiff.pick import _picks_to_idx from mne.decoding import CSP +from mne.decoding._mod_ged import _get_spectral_ratio from mne.decoding.ssd import SSD from mne.filter import filter_data from mne.time_frequency import psd_array_welch @@ -22,6 +25,13 @@ freqs_sig = 9, 12 freqs_noise = 8, 13 +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_dir / "test_raw.fif" +event_name = data_dir / "test-eve.fif" +tmin, tmax = -0.1, 0.2 +event_id = dict(aud_l=1, vis_l=3) +start, stop = 0, 8 + def simulate_data( freqs_sig=(9, 12), @@ -219,7 +229,9 @@ def test_ssd(): sort_by_spectral_ratio=False, ) ssd.fit(X) - spec_ratio, sorter_spec = ssd.get_spectral_ratio(ssd.transform(X)) + spec_ratio, sorter_spec = _get_spectral_ratio( + ssd.transform(X), ssd.sfreq_, ssd.n_fft_, ssd.freqs_signal_, ssd.freqs_noise_ + ) # since we now that the number of true components is 5, the relative # difference should be low for the first 5 components and then increases index_diff = np.argmax(-np.diff(spec_ratio)) @@ -302,8 +314,16 @@ def test_ssd_epoched_data(): ssd.fit(X) # Check if the 5 first 5 components are the same for both - _, sorter_spec_e = ssd_e.get_spectral_ratio(ssd_e.transform(X_e)) - _, sorter_spec = ssd.get_spectral_ratio(ssd.transform(X)) + _, sorter_spec_e = _get_spectral_ratio( + ssd_e.transform(X_e), + ssd_e.sfreq_, + ssd_e.n_fft_, + ssd_e.freqs_signal_, + ssd_e.freqs_noise_, + ) + _, sorter_spec = _get_spectral_ratio( + ssd.transform(X), ssd.sfreq_, ssd.n_fft_, ssd.freqs_signal_, ssd.freqs_noise_ + ) assert_array_equal( sorter_spec_e[:n_components_true], sorter_spec[:n_components_true] ) @@ -374,8 +394,12 @@ def test_sorting(): sort_by_spectral_ratio=False, ) ssd.fit(Xtr) - _, sorter_tr = ssd.get_spectral_ratio(ssd.transform(Xtr)) - _, sorter_te = ssd.get_spectral_ratio(ssd.transform(Xte)) + _, sorter_tr = _get_spectral_ratio( + ssd.transform(Xtr), ssd.sfreq_, ssd.n_fft_, ssd.freqs_signal_, ssd.freqs_noise_ + ) + _, sorter_te = _get_spectral_ratio( + ssd.transform(Xte), ssd.sfreq_, ssd.n_fft_, ssd.freqs_signal_, ssd.freqs_noise_ + ) assert any(sorter_tr != sorter_te) # check sort_by_spectral_ratio set to True @@ -389,7 +413,7 @@ def test_sorting(): ssd.fit(Xtr) # check sorters - sorter_in = ssd.sorter_spec_ + sorter_in = ssd.sorter_ ssd = SSD( info, filt_params_signal, @@ -398,7 +422,9 @@ def test_sorting(): sort_by_spectral_ratio=False, ) ssd.fit(Xtr) - _, sorter_out = ssd.get_spectral_ratio(ssd.transform(Xtr)) + _, sorter_out = _get_spectral_ratio( + ssd.transform(Xtr), ssd.sfreq_, ssd.n_fft_, ssd.freqs_signal_, ssd.freqs_noise_ + ) assert all(sorter_in == sorter_out) @@ -486,6 +512,64 @@ def test_non_full_rank_data(): ssd.fit(X) +def test_picks_arg(): + """Test that picks argument works as expected.""" + raw = io.read_raw_fif(raw_fname, preload=False) + events = read_events(event_name) + picks = pick_types( + raw.info, meg=True, eeg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) + raw.add_proj([], remove_existing=True) + epochs = Epochs( + raw, + events, + event_id, + -0.1, + 1, + picks=picks, + baseline=(None, 0), + preload=True, + proj=False, + ) + X = epochs.get_data(copy=False) + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=3, + h_trans_bandwidth=3, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=3, + h_trans_bandwidth=3, + ) + picks = ["eeg"] + info = epochs.info + picks_idx = _picks_to_idx(info, picks) + + # Test when return_filtered is False + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + picks=picks_idx, + return_filtered=False, + ) + ssd.fit(X).transform(X) + + # Test when return_filtered is true + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + picks=picks_idx, + return_filtered=True, + n_fft=64, + ) + ssd.fit(X).transform(X) + + @pytest.mark.filterwarnings("ignore:.*invalid value encountered in divide.*") @pytest.mark.filterwarnings("ignore:.*is longer than.*") @parametrize_with_checks( @@ -510,4 +594,5 @@ def test_sklearn_compliance(estimator, check): ) if any(ignore in str(check) for ignore in ignores): return + check(estimator) diff --git a/mne/decoding/xdawn.py b/mne/decoding/xdawn.py new file mode 100644 index 00000000000..fb9d799ba3f --- /dev/null +++ b/mne/decoding/xdawn.py @@ -0,0 +1,206 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import collections.abc as abc +from functools import partial + +import numpy as np + +from .._fiff.meas_info import Info +from ..cov import Covariance +from ..decoding._covs_ged import _xdawn_estimate +from ..decoding._mod_ged import _xdawn_mod +from ..decoding.base import _GEDTransformer +from ..utils import _validate_type, fill_doc + + +@fill_doc +class XdawnTransformer(_GEDTransformer): + """Implementation of the Xdawn Algorithm compatible with scikit-learn. + + Xdawn is a spatial filtering method designed to improve the signal + to signal + noise ratio (SSNR) of the event related responses. Xdawn was + originally designed for P300 evoked potential by enhancing the target + response with respect to the non-target response. This implementation is a + generalization to any type of event related response. + + .. note:: XdawnTransformer does not correct for epochs overlap. To correct + overlaps see ``Xdawn``. + + Parameters + ---------- + n_components : int (default 2) + The number of components to decompose the signals. + reg : float | str | None (default None) + If not None (same as ``'empirical'``, default), allow + regularization for covariance estimation. + If float, shrinkage is used (0 <= shrinkage <= 1). + For str options, ``reg`` will be passed to ``method`` to + :func:`mne.compute_covariance`. + signal_cov : None | Covariance | array, shape (n_channels, n_channels) + The signal covariance used for whitening of the data. + if None, the covariance is estimated from the epochs signal. + cov_method_params : dict | None + Parameters to pass to :func:`mne.compute_covariance`. + + .. versionadded:: 0.16 + restr_type : "restricting" | "whitening" | None + Restricting transformation for covariance matrices before performing + generalized eigendecomposition. + If "restricting" only restriction to the principal subspace of signal_cov + will be performed. + If "whitening", covariance matrices will be additionally rescaled according + to the whitening for the signal_cov. + If None, no restriction will be applied. Defaults to None. + + .. versionadded:: 1.10 + info : mne.Info | None + The mne.Info object with information about the sensors and methods of + measurement used for covariance estimation and generalized + eigendecomposition. + If None, one channel type and no projections will be assumed and if + rank is dict, it will be sum of ranks per channel type. + Defaults to None. + + .. versionadded:: 1.10 + %(rank_full)s + + .. versionadded:: 1.10 + + Attributes + ---------- + classes_ : array, shape (n_classes) + The event indices of the classes. + filters_ : array, shape (n_channels, n_channels) + The Xdawn components used to decompose the data for each event type. + patterns_ : array, shape (n_channels, n_channels) + The Xdawn patterns used to restore the signals for each event type. + + See Also + -------- + CSP, SPoC, SSD + """ + + def __init__( + self, + n_components=2, + reg=None, + signal_cov=None, + cov_method_params=None, + restr_type=None, + info=None, + rank="full", + ): + self.n_components = n_components + self.signal_cov = signal_cov + self.reg = reg + self.cov_method_params = cov_method_params + self.restr_type = restr_type + self.info = info + self.rank = rank + + cov_callable = partial( + _xdawn_estimate, + reg=reg, + cov_method_params=cov_method_params, + R=signal_cov, + info=info, + rank=rank, + ) + super().__init__( + n_components=n_components, + cov_callable=cov_callable, + mod_ged_callable=_xdawn_mod, + dec_type="multi", + restr_type=restr_type, + ) + + def _validate_params(self, X): + _validate_type(self.n_components, int, "n_components") + + # reg is validated in _regularized_covariance + + if self.signal_cov is not None: + if isinstance(self.signal_cov, Covariance): + self.signal_cov = self.signal_cov.data + elif not isinstance(self.signal_cov, np.ndarray): + raise ValueError("signal_cov should be mne.Covariance or np.ndarray") + if not np.array_equal(self.signal_cov.shape, np.tile(X.shape[1], 2)): + raise ValueError( + "signal_cov data should be of shape (n_channels, n_channels)" + ) + _validate_type(self.cov_method_params, (abc.Mapping, None), "cov_method_params") + _validate_type(self.info, (Info, None), "info") + + def fit(self, X, y=None): + """Fit Xdawn spatial filters. + + Parameters + ---------- + X : array, shape (n_epochs, n_channels, n_samples) + The target data. + y : array, shape (n_epochs,) | None + The target labels. If None, Xdawn fit on the average evoked. + + Returns + ------- + self : Xdawn instance + The Xdawn instance. + """ + X, y = self._check_data(X, y=y, fit=True, return_y=True) + # For test purposes + if y is None: + y = np.ones(len(X)) + self._validate_params(X) + + super().fit(X, y) + + return self + + def transform(self, X): + """Transform data with spatial filters. + + Parameters + ---------- + X : array, shape (n_epochs, n_channels, n_samples) + The target data. + + Returns + ------- + X : array, shape (n_epochs, n_components * n_classes, n_samples) + The transformed data. + """ + X = self._check_data(X) + X = super().transform(X) + return X + + def inverse_transform(self, X): + """Remove selected components from the signal. + + Given the unmixing matrix, transform data, zero out components, + and inverse transform the data. This procedure will reconstruct + the signals from which the dynamics described by the excluded + components is subtracted. + + Parameters + ---------- + X : array, shape (n_epochs, n_components * n_classes, n_times) + The transformed data. + + Returns + ------- + X : array, shape (n_epochs, n_channels * n_classes, n_times) + The inverse transform data. + """ + # Check size + X = self._check_data(X, check_n_features=False) + n_epochs, n_comp, n_times = X.shape + if n_comp != (self.n_components * len(self.classes_)): + raise ValueError( + f"X must have {self.n_components * len(self.classes_)} components, " + f"got {n_comp} instead." + ) + pick_patterns = self._subset_multi_components(name="patterns") + # Transform + return np.dot(pick_patterns.T, X).transpose(1, 0, 2) diff --git a/mne/preprocessing/tests/test_xdawn.py b/mne/preprocessing/tests/test_xdawn.py index c30fd5dcfd9..47585112ba3 100644 --- a/mne/preprocessing/tests/test_xdawn.py +++ b/mne/preprocessing/tests/test_xdawn.py @@ -22,7 +22,8 @@ pytest.importorskip("sklearn") -from mne.preprocessing.xdawn import Xdawn, _XdawnTransformer # noqa: E402 +from mne.decoding.xdawn import XdawnTransformer +from mne.preprocessing.xdawn import Xdawn # noqa: E402 base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" @@ -236,7 +237,7 @@ def test_xdawn_regularization(): def test_XdawnTransformer(): - """Test _XdawnTransformer.""" + """Test XdawnTransformer.""" pytest.importorskip("sklearn") # Get data raw, events, picks = _get_data() @@ -255,40 +256,42 @@ def test_XdawnTransformer(): X = epochs._data y = epochs.events[:, -1] # Fit - xdt = _XdawnTransformer() + xdt = XdawnTransformer() xdt.fit(X, y) pytest.raises(ValueError, xdt.fit, X, y[1:]) pytest.raises(ValueError, xdt.fit, "foo") # Provide covariance object signal_cov = compute_raw_covariance(raw, picks=picks) - xdt = _XdawnTransformer(signal_cov=signal_cov) + xdt = XdawnTransformer(signal_cov=signal_cov) xdt.fit(X, y) # Provide ndarray signal_cov = np.eye(len(picks)) - xdt = _XdawnTransformer(signal_cov=signal_cov) + xdt = XdawnTransformer(signal_cov=signal_cov) xdt.fit(X, y) # Provide ndarray of bad shape signal_cov = np.eye(len(picks) - 1) - xdt = _XdawnTransformer(signal_cov=signal_cov) + xdt = XdawnTransformer(signal_cov=signal_cov) pytest.raises(ValueError, xdt.fit, X, y) # Provide another type signal_cov = 42 - xdt = _XdawnTransformer(signal_cov=signal_cov) + xdt = XdawnTransformer(signal_cov=signal_cov) pytest.raises(ValueError, xdt.fit, X, y) # Fit with y as None - xdt = _XdawnTransformer() + xdt = XdawnTransformer() xdt.fit(X) - # Compare xdawn and _XdawnTransformer + # Compare xdawn and XdawnTransformer xd = Xdawn(correct_overlap=False) xd.fit(epochs) - xdt = _XdawnTransformer() + xdt = XdawnTransformer() xdt.fit(X, y) + # Subset filters + xdt_filters = xdt._subset_multi_components() assert_array_almost_equal( - xd.filters_["cond2"][:2, :], xdt.filters_.reshape(2, 2, 8)[0] + xd.filters_["cond2"][:2, :], xdt_filters.reshape(2, 2, 8)[0] ) # Transform testing @@ -363,7 +366,7 @@ def test_xdawn_decoding_performance(): epochs, mixing_mat = _simulate_erplike_mixed_data(n_epochs=100) y = epochs.events[:, 2] - # results of Xdawn and _XdawnTransformer should match + # results of Xdawn and XdawnTransformer should match xdawn_pipe = make_pipeline( Xdawn(n_components=n_xdawn_comps), Vectorizer(), @@ -371,7 +374,7 @@ def test_xdawn_decoding_performance(): LogisticRegression(solver="liblinear"), ) xdawn_trans_pipe = make_pipeline( - _XdawnTransformer(n_components=n_xdawn_comps), + XdawnTransformer(n_components=n_xdawn_comps), Vectorizer(), MinMaxScaler(), LogisticRegression(solver="liblinear"), @@ -397,7 +400,8 @@ def test_xdawn_decoding_performance(): [comps[[0]] for comps in fitted_xdawn.patterns_.values()] ) else: - relev_patterns = fitted_xdawn.patterns_[::n_xdawn_comps] + pick_patterns = fitted_xdawn._subset_multi_components(name="patterns") + relev_patterns = pick_patterns[::n_xdawn_comps] for i in range(len(relev_patterns)): r, _ = stats.pearsonr(relev_patterns[i, :], mixing_mat[0, :]) diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index 606b49370df..fd061323a36 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -7,7 +7,7 @@ from .._fiff.pick import _pick_data_channels, pick_info from ..cov import Covariance, _regularized_covariance -from ..decoding import BaseEstimator, TransformerMixin +from ..decoding.xdawn import XdawnTransformer from ..epochs import BaseEpochs from ..evoked import Evoked, EvokedArray from ..io import BaseRaw @@ -202,7 +202,7 @@ def _fit_xdawn( ) evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs) - _patterns = np.linalg.pinv(evecs.T) + _patterns = pinv(evecs.T) filters.append(evecs[:, :n_components].T) patterns.append(_patterns[:, :n_components].T) @@ -212,155 +212,7 @@ def _fit_xdawn( return filters, patterns, evokeds -class _XdawnTransformer(BaseEstimator, TransformerMixin): - """Implementation of the Xdawn Algorithm compatible with scikit-learn. - - Xdawn is a spatial filtering method designed to improve the signal - to signal + noise ratio (SSNR) of the event related responses. Xdawn was - originally designed for P300 evoked potential by enhancing the target - response with respect to the non-target response. This implementation is a - generalization to any type of event related response. - - .. note:: _XdawnTransformer does not correct for epochs overlap. To correct - overlaps see ``Xdawn``. - - Parameters - ---------- - n_components : int (default 2) - The number of components to decompose the signals. - reg : float | str | None (default None) - If not None (same as ``'empirical'``, default), allow - regularization for covariance estimation. - If float, shrinkage is used (0 <= shrinkage <= 1). - For str options, ``reg`` will be passed to ``method`` to - :func:`mne.compute_covariance`. - signal_cov : None | Covariance | array, shape (n_channels, n_channels) - The signal covariance used for whitening of the data. - if None, the covariance is estimated from the epochs signal. - method_params : dict | None - Parameters to pass to :func:`mne.compute_covariance`. - - .. versionadded:: 0.16 - - Attributes - ---------- - classes_ : array, shape (n_classes) - The event indices of the classes. - filters_ : array, shape (n_channels, n_channels) - The Xdawn components used to decompose the data for each event type. - patterns_ : array, shape (n_channels, n_channels) - The Xdawn patterns used to restore the signals for each event type. - """ - - def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None): - """Init.""" - self.n_components = n_components - self.signal_cov = signal_cov - self.reg = reg - self.method_params = method_params - - def fit(self, X, y=None): - """Fit Xdawn spatial filters. - - Parameters - ---------- - X : array, shape (n_epochs, n_channels, n_samples) - The target data. - y : array, shape (n_epochs,) | None - The target labels. If None, Xdawn fit on the average evoked. - - Returns - ------- - self : Xdawn instance - The Xdawn instance. - """ - X, y = self._check_Xy(X, y) - - # Main function - self.classes_ = np.unique(y) - self.filters_, self.patterns_, _ = _fit_xdawn( - X, - y, - n_components=self.n_components, - reg=self.reg, - signal_cov=self.signal_cov, - method_params=self.method_params, - ) - return self - - def transform(self, X): - """Transform data with spatial filters. - - Parameters - ---------- - X : array, shape (n_epochs, n_channels, n_samples) - The target data. - - Returns - ------- - X : array, shape (n_epochs, n_components * n_classes, n_samples) - The transformed data. - """ - X, _ = self._check_Xy(X) - - # Check size - if self.filters_.shape[1] != X.shape[1]: - raise ValueError( - f"X must have {self.filters_.shape[1]} channels, got {X.shape[1]} " - "instead." - ) - - # Transform - X = np.dot(self.filters_, X) - X = X.transpose((1, 0, 2)) - return X - - def inverse_transform(self, X): - """Remove selected components from the signal. - - Given the unmixing matrix, transform data, zero out components, - and inverse transform the data. This procedure will reconstruct - the signals from which the dynamics described by the excluded - components is subtracted. - - Parameters - ---------- - X : array, shape (n_epochs, n_components * n_classes, n_times) - The transformed data. - - Returns - ------- - X : array, shape (n_epochs, n_channels * n_classes, n_times) - The inverse transform data. - """ - # Check size - X, _ = self._check_Xy(X) - n_epochs, n_comp, n_times = X.shape - if n_comp != (self.n_components * len(self.classes_)): - raise ValueError( - f"X must have {self.n_components * len(self.classes_)} components, " - f"got {n_comp} instead." - ) - - # Transform - return np.dot(self.patterns_.T, X).transpose(1, 0, 2) - - def _check_Xy(self, X, y=None): - """Check X and y types and dimensions.""" - # Check data - if not isinstance(X, np.ndarray) or X.ndim != 3: - raise ValueError( - "X must be an array of shape (n_epochs, n_channels, n_samples)." - ) - if y is None: - y = np.ones(len(X)) - y = np.asarray(y) - if len(X) != len(y): - raise ValueError("X and y must have the same length") - return X, y - - -class Xdawn(_XdawnTransformer): +class Xdawn(XdawnTransformer): """Implementation of the Xdawn Algorithm. Xdawn :footcite:`RivetEtAl2009,RivetEtAl2011` is a spatial @@ -483,7 +335,7 @@ def fit(self, epochs, y=None): events=events, tmin=tmin, sfreq=sfreq, - method_params=self.method_params, + method_params=self.cov_method_params, info=use_info, ) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 5ae3be53a3d..07c47ab7f28 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3697,6 +3697,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ docdict["rank"] = _rank_base +docdict["rank_full"] = _rank_base + "\n The default is ``'full'``." docdict["rank_info"] = _rank_base + "\n The default is ``'info'``." docdict["rank_none"] = _rank_base + "\n The default is ``None``." diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 9d0e215ee80..32ee1091131 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -43,6 +43,9 @@ _._more_tags _.multi_class _.preserves_dtype +_.one_d_labels +_.two_d_array +_.three_d_array deep # Backward compat or rarely used