Skip to content

ENH: Add GED transformer #13259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 84 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
632c819
assert_allclose for base ged for csp, spoc, ssd and xdawn
Genuster May 22, 2025
4f8b5fa
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster May 22, 2025
7c072d1
update _epoch_cov logging following merge
Genuster May 22, 2025
211d23f
add a few preliminary docstrings
Genuster May 22, 2025
a701e42
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster May 22, 2025
0d58c8d
bump rtol/atol for spoc
Genuster May 22, 2025
2a1c5cb
Add big sklearn compliance test
Genuster May 29, 2025
3e9c32c
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster May 29, 2025
6e8b3aa
add __sklearn_tags__ to vulture's whitelist
Genuster Jun 2, 2025
b2e24ea
calm vulture down per attribute
Genuster Jun 2, 2025
fbd585e
put the TransformerMixin back
Genuster Jun 2, 2025
a9d5390
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 2, 2025
d142bd0
fix validation of covariances
Genuster Jun 2, 2025
9329494
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 3, 2025
6796366
add gedtranformer tests with audvis dataset
Genuster Jun 3, 2025
7a291b1
fixes following Eric's comments
Genuster Jun 4, 2025
9b34bd3
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 4, 2025
7c867ec
document shapes
Genuster Jun 4, 2025
e1e8d6d
another small test for GEDtransformer
Genuster Jun 6, 2025
5edc6fa
change name of restricting map to restricting matrix
Genuster Jun 6, 2025
89fb141
a few more ged tests
Genuster Jun 6, 2025
3986c99
fix multiplication order in original SSD
Genuster Jun 6, 2025
11b038f
add assert_allclose to xdawn and csp transform methods.
Genuster Jun 6, 2025
25e1ae3
more ged tests
Genuster Jun 6, 2025
6bbc459
clean up _xdawn_estimate
Genuster Jun 10, 2025
99d297e
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 10, 2025
029691b
add _validate_params for _XdawnTransformer
Genuster Jun 10, 2025
f38ce7d
review suggestions
Genuster Jun 13, 2025
11c31f7
address Eric's suggestions
Genuster Jun 13, 2025
85fb50f
add default no op for mod_ged_callable
Genuster Jun 13, 2025
3c7df08
replace mod_params with partial as well
Genuster Jun 13, 2025
9c7c711
add ged entry in the implementation details
Genuster Jun 13, 2025
95544c5
add feature to perform GED in the principal subspace for xdawn
Genuster Jun 13, 2025
8755089
add option for CSP to select restr_type and provide info
Genuster Jun 13, 2025
87a2466
add restr_type for SCoP and SSD
Genuster Jun 13, 2025
969a73e
fix SSD's filters_ shape inconsistency
Genuster Jun 13, 2025
5266372
use mne's pinv in SSD and Xdawn instead of np.linalg.pinv
Genuster Jun 13, 2025
726c500
move mne.preprocessing._XdawnTransformer to decoding and make it public
Genuster Jun 13, 2025
8e8bf3f
fix docstring
Genuster Jun 13, 2025
a374546
fix some terminological imprecisions in the implementation details
Genuster Jun 13, 2025
226abf4
add parameter validation for gedtransformer
Genuster Jun 15, 2025
3da3266
slightly improve validation in csp and ssd
Genuster Jun 15, 2025
4f5d436
rename xdawntranformer's method_params to cov_method_params for consi…
Genuster Jun 15, 2025
5e12465
add picks test for ssd
Genuster Jun 20, 2025
2bfc931
make ssd store ordered filters instead of sorting in transform
Genuster Jun 23, 2025
3a0dd1a
add expected failure for sklearn compliance test
Genuster Jun 25, 2025
19d01bd
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 26, 2025
29da3ff
better solution for the previous fix
Genuster Jun 26, 2025
8d1e656
add temporary xfail for windows pip CIs
Genuster Jun 27, 2025
ffac4fd
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 27, 2025
3761ff4
another try
Genuster Jun 27, 2025
8903664
and another
Genuster Jun 27, 2025
60d7360
add sorter return for _mod_ged functions
Genuster Jun 27, 2025
f8b8d6b
(1) clean up csp and remove asserts
Genuster Jun 27, 2025
10224fc
(2) clean up spoc and remove asserts
Genuster Jun 27, 2025
ac077f7
(3) clean up xdawn, remove asserts and make it store all filters and …
Genuster Jun 27, 2025
e23448c
(4) clean up ssd and remove asserts
Genuster Jun 27, 2025
5c612cb
more tests
Genuster Jul 4, 2025
17fae36
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 4, 2025
3c782b3
replace ssd's old whitening with compute_whitener
Genuster Jul 4, 2025
7522910
make XdawnTransformer properly public
Genuster Jul 4, 2025
a57934b
add changelog entry
Genuster Jul 4, 2025
7951a43
fix xdawntransformer docstring
Genuster Jul 4, 2025
b3a378d
more docstring adventures
Genuster Jul 5, 2025
7b9ad11
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 5, 2025
268f54f
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 6, 2025
10afd78
fix docdict order
Genuster Jul 6, 2025
5049e8c
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 11, 2025
9ec6835
make all init arguments have default in ged transformer
Genuster Jul 11, 2025
99f6e9c
Merge branch 'main' into base-GED
larsoner Jul 11, 2025
9364b04
Merge branch 'main' into base-GED
larsoner Jul 11, 2025
2b690ea
temporarily skip the problematic test
Genuster Jul 12, 2025
0117163
unskip the test
Genuster Jul 12, 2025
b08cf81
fix unskip
Genuster Jul 12, 2025
e4e38c3
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
3fc96c0
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
4c1b8ff
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
0bd74a5
WIP: Test more [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
67f498d
FIX: More [ci skip]
larsoner Jul 15, 2025
f2451ba
WIP: Tests
larsoner Jul 15, 2025
23c37cb
FIX: States
larsoner Jul 15, 2025
13821c7
Merge branch 'main' into base-GED
larsoner Jul 15, 2025
35f5749
update versions and add stars
Genuster Jul 15, 2025
d713bfe
Merge branch 'base-GED' of https://github.com/Genuster/mne-python int…
Genuster Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 198 additions & 1 deletion mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numbers

import numpy as np
import scipy.linalg
from sklearn import model_selection as models
from sklearn.base import ( # noqa: F401
BaseEstimator,
Expand All @@ -20,9 +21,205 @@
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, check_X_y, indexable
from sklearn.utils.validation import check_is_fitted

from ..parallel import parallel_func
from ..utils import _pl, logger, verbose, warn
from ..utils import _pl, logger, pinv, verbose, warn
from .ged import _handle_restr_map, _smart_ajd, _smart_ged
from .transformer import MNETransformerMixin


class GEDTransformer(MNETransformerMixin, BaseEstimator):
"""M/EEG signal decomposition using the generalized eigenvalue decomposition (GED).

Given two channel covariance matrices S and R, the goal is to find spatial filters
that maximise contrast between S and R.

Parameters
----------
n_filters : int
The number of spatial filters to decompose M/EEG signals.
cov_callable : callable
Function used to estimate covariances and reference matrix (C_ref) from the
data.
cov_params : dict
Parameters passed to cov_callable.
mod_ged_callable : callable
Function used to modify (e.g. sort or normalize) generalized
eigenvalues and eigenvectors.
mod_params : dict
Parameters passed to mod_ged_callable.
dec_type : "single" | "multi"
When "single" and cov_callable returns > 2 covariances,
approximate joint diagonalization based on Pham's algorithm
will be used instead of GED.
When 'multi', GED is performed separately for each class, i.e. each covariance
(except the last) returned by cov_callable is decomposed with the last
covariance. In this case, number of covariances should be number of classes + 1.
Defaults to "single".
restr_type : "restricting" | "whitening" | "ssd" | None
Restricting transformation for covariance matrices before performing GED.
If "restricting" only restriction to the principal subspace of the C_ref
will be performed.
If "whitening", covariance matrices will be additionally rescaled according
to the whitening for the C_ref.
If "ssd", perform simplified version of "whitening",
preserved for compatibility.
If None, no restriction will be applied. Defaults to None.
R_func : callable | None
If provided GED will be performed on (S, R_func(S,R)).

Attributes
----------
evals_ : ndarray, shape (n_channels)
If fit, generalized eigenvalues used to decompose S and R, else None.
filters_ : ndarray, shape (n_channels or less, n_channels)
If fit, spatial filters (unmixing matrix) used to decompose the data,
else None.
patterns_ : ndarray, shape (n_channels or less, n_channels)
If fit, spatial patterns (mixing matrix) used to restore M/EEG signals,
else None.

See Also
--------
CSP
SPoC
SSD
mne.preprocessing.Xdawn
"""

def __init__(
self,
n_filters,
cov_callable,
cov_params,
mod_ged_callable,
mod_params,
dec_type="single",
restr_type=None,
R_func=None,
):
self.n_filters = n_filters
self.cov_callable = cov_callable
self.cov_params = cov_params
self.mod_ged_callable = mod_ged_callable
self.mod_params = mod_params
self.dec_type = dec_type
self.restr_type = restr_type
self.R_func = R_func

def fit(self, X, y=None):
"""..."""
X, y = self._check_data(
X,
y=y,
fit=True,
return_y=True,
atleast_3d=False if self.restr_type == "ssd" else True,
)
covs, C_ref, info, rank, kwargs = self.cov_callable(X, y, **self.cov_params)
covs = np.stack(covs)
self._validate_covariances(covs)
self._validate_covariances([C_ref])
if self.dec_type == "single":
if len(covs) > 2:
sample_weights = kwargs["sample_weights"]
restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank)
evecs = _smart_ajd(covs, restr_map, weights=sample_weights)
evals = None
else:
S = covs[0]
R = covs[1]
if self.restr_type == "ssd":
mult_order = "ssd"
else:
mult_order = None
restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank)
evals, evecs = _smart_ged(
S, R, restr_map, R_func=self.R_func, mult_order=mult_order
)

evals, evecs = self.mod_ged_callable(
evals, evecs, covs, **self.mod_params, **kwargs
)
self.evals_ = evals
self.filters_ = evecs.T
if self.restr_type == "ssd":
self.patterns_ = np.linalg.pinv(evecs)
else:
self.patterns_ = pinv(evecs)

elif self.dec_type == "multi":
self.classes_ = np.unique(y)
R = covs[-1]
if self.restr_type == "ssd":
mult_order = "ssd"
else:
mult_order = None
restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank)
all_evals, all_evecs, all_patterns = list(), list(), list()
for i in range(len(self.classes_)):
S = covs[i]

evals, evecs = _smart_ged(
S, R, restr_map, R_func=self.R_func, mult_order=mult_order
)

evals, evecs = self.mod_ged_callable(
evals, evecs, covs, **self.mod_params, **kwargs
)
all_evals.append(evals)
all_evecs.append(evecs.T)
all_patterns.append(np.linalg.pinv(evecs))
self.evals_ = np.array(all_evals)
self.filters_ = np.array(all_evecs)
self.patterns_ = np.array(all_patterns)

return self

def transform(self, X):
"""..."""
check_is_fitted(self, "filters_")
X = self._check_data(X)
if self.dec_type == "single":
pick_filters = self.filters_[: self.n_filters]
elif self.dec_type == "multi":
pick_filters = np.concatenate(
[
self.filters_[i, : self.n_filters]
for i in range(self.filters_.shape[0])
],
axis=0,
)
X = np.asarray([pick_filters @ epoch for epoch in X])
return X

def _validate_covariances(self, covs):
for cov in covs:
if cov is None:
continue
is_sym = scipy.linalg.issymmetric(cov, rtol=1e-10, atol=1e-11)
if not is_sym:
raise ValueError(
"One of covariances or C_ref is not symmetric, "
"check your cov_callable"
)
if not np.all(np.linalg.eigvals(cov) >= 0):
ValueError(
"One of covariances or C_ref has negative eigenvalues, "
"check your cov_callable"
)

def __sklearn_tags__(self):
"""Tag the transformer."""
tags = super().__sklearn_tags__()
tags.estimator_type = "transformer"
# Can be a transformer where S and R covs are not based on y classes.
tags.target_tags.required = False
tags.target_tags.one_d_labels = True
tags.input_tags.two_d_array = True
tags.input_tags.three_d_array = True
return tags


class LinearModel(MetaEstimatorMixin, BaseEstimator):
Expand Down
Loading