ENH: Add GED transformer #13259

Merged on Jul 15, 2025 (84 commits; changes shown from 17 commits)

Commits
632c819
assert_allclose for base ged for csp, spoc, ssd and xdawn
Genuster May 22, 2025
4f8b5fa
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster May 22, 2025
7c072d1
update _epoch_cov logging following merge
Genuster May 22, 2025
211d23f
add a few preliminary docstrings
Genuster May 22, 2025
a701e42
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster May 22, 2025
0d58c8d
bump rtol/atol for spoc
Genuster May 22, 2025
2a1c5cb
Add big sklearn compliance test
Genuster May 29, 2025
3e9c32c
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster May 29, 2025
6e8b3aa
add __sklearn_tags__ to vulture's whitelist
Genuster Jun 2, 2025
b2e24ea
calm vulture down per attribute
Genuster Jun 2, 2025
fbd585e
put the TransformerMixin back
Genuster Jun 2, 2025
a9d5390
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 2, 2025
d142bd0
fix validation of covariances
Genuster Jun 2, 2025
9329494
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 3, 2025
6796366
add gedtranformer tests with audvis dataset
Genuster Jun 3, 2025
7a291b1
fixes following Eric's comments
Genuster Jun 4, 2025
9b34bd3
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 4, 2025
7c867ec
document shapes
Genuster Jun 4, 2025
e1e8d6d
another small test for GEDtransformer
Genuster Jun 6, 2025
5edc6fa
change name of restricting map to restricting matrix
Genuster Jun 6, 2025
89fb141
a few more ged tests
Genuster Jun 6, 2025
3986c99
fix multiplication order in original SSD
Genuster Jun 6, 2025
11b038f
add assert_allclose to xdawn and csp transform methods.
Genuster Jun 6, 2025
25e1ae3
more ged tests
Genuster Jun 6, 2025
6bbc459
clean up _xdawn_estimate
Genuster Jun 10, 2025
99d297e
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 10, 2025
029691b
add _validate_params for _XdawnTransformer
Genuster Jun 10, 2025
f38ce7d
review suggestions
Genuster Jun 13, 2025
11c31f7
address Eric's suggestions
Genuster Jun 13, 2025
85fb50f
add default no op for mod_ged_callable
Genuster Jun 13, 2025
3c7df08
replace mod_params with partial as well
Genuster Jun 13, 2025
9c7c711
add ged entry in the implementation details
Genuster Jun 13, 2025
95544c5
add feature to perform GED in the principal subspace for xdawn
Genuster Jun 13, 2025
8755089
add option for CSP to select restr_type and provide info
Genuster Jun 13, 2025
87a2466
add restr_type for SCoP and SSD
Genuster Jun 13, 2025
969a73e
fix SSD's filters_ shape inconsistency
Genuster Jun 13, 2025
5266372
use mne's pinv in SSD and Xdawn instead of np.linalg.pinv
Genuster Jun 13, 2025
726c500
move mne.preprocessing._XdawnTransformer to decoding and make it public
Genuster Jun 13, 2025
8e8bf3f
fix docstring
Genuster Jun 13, 2025
a374546
fix some terminological imprecisions in the implementation details
Genuster Jun 13, 2025
226abf4
add parameter validation for gedtransformer
Genuster Jun 15, 2025
3da3266
slightly improve validation in csp and ssd
Genuster Jun 15, 2025
4f5d436
rename xdawntranformer's method_params to cov_method_params for consi…
Genuster Jun 15, 2025
5e12465
add picks test for ssd
Genuster Jun 20, 2025
2bfc931
make ssd store ordered filters instead of sorting in transform
Genuster Jun 23, 2025
3a0dd1a
add expected failure for sklearn compliance test
Genuster Jun 25, 2025
19d01bd
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 26, 2025
29da3ff
better solution for the previous fix
Genuster Jun 26, 2025
8d1e656
add temporary xfail for windows pip CIs
Genuster Jun 27, 2025
ffac4fd
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jun 27, 2025
3761ff4
another try
Genuster Jun 27, 2025
8903664
and another
Genuster Jun 27, 2025
60d7360
add sorter return for _mod_ged functions
Genuster Jun 27, 2025
f8b8d6b
(1) clean up csp and remove asserts
Genuster Jun 27, 2025
10224fc
(2) clean up spoc and remove asserts
Genuster Jun 27, 2025
ac077f7
(3) clean up xdawn, remove asserts and make it store all filters and …
Genuster Jun 27, 2025
e23448c
(4) clean up ssd and remove asserts
Genuster Jun 27, 2025
5c612cb
more tests
Genuster Jul 4, 2025
17fae36
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 4, 2025
3c782b3
replace ssd's old whitening with compute_whitener
Genuster Jul 4, 2025
7522910
make XdawnTransformer properly public
Genuster Jul 4, 2025
a57934b
add changelog entry
Genuster Jul 4, 2025
7951a43
fix xdawntransformer docstring
Genuster Jul 4, 2025
b3a378d
more docstring adventures
Genuster Jul 5, 2025
7b9ad11
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 5, 2025
268f54f
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 6, 2025
10afd78
fix docdict order
Genuster Jul 6, 2025
5049e8c
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster Jul 11, 2025
9ec6835
make all init arguments have default in ged transformer
Genuster Jul 11, 2025
99f6e9c
Merge branch 'main' into base-GED
larsoner Jul 11, 2025
9364b04
Merge branch 'main' into base-GED
larsoner Jul 11, 2025
2b690ea
temporarily skip the problematic test
Genuster Jul 12, 2025
0117163
unskip the test
Genuster Jul 12, 2025
b08cf81
fix unskip
Genuster Jul 12, 2025
e4e38c3
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
3fc96c0
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
4c1b8ff
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
0bd74a5
WIP: Test more [actions ssh] [skip circle] [skip azp]
larsoner Jul 15, 2025
67f498d
FIX: More [ci skip]
larsoner Jul 15, 2025
f2451ba
WIP: Tests
larsoner Jul 15, 2025
23c37cb
FIX: States
larsoner Jul 15, 2025
13821c7
Merge branch 'main' into base-GED
larsoner Jul 15, 2025
35f5749
update versions and add stars
Genuster Jul 15, 2025
d713bfe
Merge branch 'base-GED' of https://github.com/Genuster/mne-python int…
Genuster Jul 15, 2025
266 changes: 266 additions & 0 deletions mne/decoding/_covs_ged.py
@@ -0,0 +1,266 @@
"""Covariance estimation for GED transformers."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from .._fiff.meas_info import Info, create_info
from .._fiff.pick import _picks_to_idx
from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance
from ..defaults import _handle_default
from ..filter import filter_data
from ..rank import compute_rank
from ..utils import _verbose_safe_false, logger


def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info):
    """Concatenate epochs before computing the covariance."""
    _, n_channels, _ = x_class.shape

    x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1)
    cov = _regularized_covariance(
        x_class,
        reg=reg,
        method_params=cov_method_params,
        rank=rank,
        info=info,
        cov_kind=cov_kind,
        log_rank=log_rank,
        log_ch_type="data",
    )
    weight = x_class.shape[0]

    return cov, weight


def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info):
    """Mean of per-epoch covariances."""
    name = reg if isinstance(reg, str) else "empirical"
    name += " with shrinkage" if isinstance(reg, float) else ""
    logger.info(
        f"Estimating {cov_kind + (' ' if cov_kind else '')}"
        f"covariance (average over epochs; {name.upper()})"
    )
    cov = sum(
        _regularized_covariance(
            this_X,
            reg=reg,
            method_params=cov_method_params,
            rank=rank,
            info=info,
            cov_kind=cov_kind,
            log_rank=log_rank and ii == 0,
            log_ch_type="data",
            verbose=_verbose_safe_false(),
        )
        for ii, this_X in enumerate(x_class)
    )
    cov /= len(x_class)
    weight = len(x_class)

    return cov, weight


def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace):
    _, n_channels, _ = X.shape
    classes_ = np.unique(y)
    if cov_est == "concat":
        cov_estimator = _concat_cov
    elif cov_est == "epoch":
        cov_estimator = _epoch_cov
    # Someday we could allow the user to pass this, then we wouldn't need to convert
    # but in the meantime they can use a pipeline with a scaler
    _info = create_info(n_channels, 1000.0, "mag")
    if isinstance(rank, dict):
        _rank = {"mag": sum(rank.values())}
    else:
        _rank = _compute_rank_raw_array(
            X.transpose(1, 0, 2).reshape(X.shape[1], -1),
            _info,
            rank=rank,
            scalings=None,
            log_ch_type="data",
        )

    covs = []
    sample_weights = []
    for ci, this_class in enumerate(classes_):
        cov, weight = cov_estimator(
            X[y == this_class],
            cov_kind=f"class={this_class}",
            log_rank=ci == 0,
            reg=reg,
            cov_method_params=cov_method_params,
            rank=_rank,
            info=_info,
        )

        if norm_trace:
            cov /= np.trace(cov)

        covs.append(cov)
        sample_weights.append(weight)

    covs = np.stack(covs)
    C_ref = covs.mean(0)

    return covs, C_ref, _info, _rank, dict(sample_weights=np.array(sample_weights))


def _xdawn_estimate(
    X,
    y,
    reg,
    cov_method_params,
    R=None,
    events=None,
    tmin=0,
    sfreq=1,
    info=None,
    rank="full",
):
    from ..preprocessing.xdawn import _least_square_evoked

    if not isinstance(X, np.ndarray) or X.ndim != 3:
        raise ValueError("X must be 3D ndarray")

    classes = np.unique(y)

    # XXX Eventually this could be made to deal with rank deficiency properly
    # by exposing this "rank" parameter, but this will require refactoring
    # the linalg.eigh call to operate in the lower-dimension
    # subspace, then project back out.

    # Retrieve or compute whitening covariance
    if R is None:
        R = _regularized_covariance(
            np.hstack(X), reg, cov_method_params, info, rank=rank
        )
    elif isinstance(R, Covariance):
        R = R.data
    if not isinstance(R, np.ndarray) or (
        not np.array_equal(R.shape, np.tile(X.shape[1], 2))
    ):
        raise ValueError(
            "R must be None, a covariance instance, "
            "or an array of shape (n_chans, n_chans)"
        )

    # Get prototype events
    if events is not None:
        evokeds, toeplitzs = _least_square_evoked(X, events, tmin, sfreq)
    else:
        evokeds, toeplitzs = list(), list()
        for c in classes:
            # Prototyped response for each class
            evokeds.append(np.mean(X[y == c, :, :], axis=0))
            toeplitzs.append(1.0)

    covs = []
    for evo, toeplitz in zip(evokeds, toeplitzs):
        # Estimate covariance matrix of the prototype response
        evo = np.dot(evo, toeplitz)
        evo_cov = _regularized_covariance(evo, reg, cov_method_params, info, rank=rank)
        covs.append(evo_cov)

    covs.append(R)
    C_ref = None
    rank = None
    info = None
    return covs, C_ref, info, rank, dict()


def _ssd_estimate(
    X,
    y,
    reg,
    cov_method_params,
    info,
    picks,
    filt_params_signal,
    filt_params_noise,
    rank,
):
    if isinstance(info, Info):
        sfreq = info["sfreq"]
    elif isinstance(info, float):  # special case, mostly for testing
        sfreq = info
        info = create_info(X.shape[-2], sfreq, ch_types="eeg")
    picks = _picks_to_idx(info, picks, none="data", exclude="bads")
    X_aux = X[..., picks, :]
    X_signal = filter_data(X_aux, sfreq, **filt_params_signal)
    X_noise = filter_data(X_aux, sfreq, **filt_params_noise)
    X_noise -= X_signal
    if X.ndim == 3:
        X_signal = np.hstack(X_signal)
        X_noise = np.hstack(X_noise)

    # prevent rank change when computing cov with rank='full'
    S = _regularized_covariance(
        X_signal,
        reg=reg,
        method_params=cov_method_params,
        rank="full",
        info=info,
    )
    R = _regularized_covariance(
        X_noise,
        reg=reg,
        method_params=cov_method_params,
        rank="full",
        info=info,
    )
    covs = [S, R]
    C_ref = S

    all_ranks = list()
    for cov in covs:
        r = list(
            compute_rank(
                Covariance(
                    cov,
                    info.ch_names,
                    list(),
                    list(),
                    0,
                    verbose=_verbose_safe_false(),
                ),
                rank,
                _handle_default("scalings_cov_rank", None),
                info,
            ).values()
        )[0]
        all_ranks.append(r)
    rank = np.min(all_ranks)
    return covs, C_ref, info, rank, dict()


def _spoc_estimate(X, y, reg, cov_method_params, rank):
    # Normalize target variable
    target = y.astype(np.float64)
    target -= target.mean()
    target /= target.std()

    n_epochs, n_channels = X.shape[:2]

    # Estimate single trial covariance
    covs = np.empty((n_epochs, n_channels, n_channels))
    for ii, epoch in enumerate(X):
        covs[ii] = _regularized_covariance(
            epoch,
            reg=reg,
            method_params=cov_method_params,
            rank=rank,
            log_ch_type="data",
            log_rank=ii == 0,
        )

    S = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0)
    R = covs.mean(0)

    covs = [S, R]
    C_ref = None
    info = None
    return covs, C_ref, info, rank, dict()
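
Review note on `_spoc_estimate`: the two matrices it returns can be reproduced in isolation with plain NumPy. The sketch below is an illustration only, not code from this PR. `spoc_covariances` is a hypothetical name, and `np.cov` stands in for MNE's private `_regularized_covariance`, so centering and regularization may differ from what the diff actually computes.

```python
import numpy as np


def spoc_covariances(X, y):
    """Return (S, R) for epochs X (n_epochs, n_channels, n_times) and target y.

    Assumption: np.cov replaces MNE's private _regularized_covariance.
    """
    # Z-score the continuous target, as in the diff.
    target = np.asarray(y, dtype=np.float64)
    target = (target - target.mean()) / target.std()
    # One covariance matrix per epoch (rows of each epoch are channels).
    covs = np.stack([np.cov(epoch) for epoch in X])
    # S: target-weighted mean covariance; R: plain mean covariance.
    S = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0)
    R = covs.mean(axis=0)
    return S, R


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4, 100))
y = rng.standard_normal(20)
S, R = spoc_covariances(X, y)
print(S.shape, R.shape)
```

Because the target is z-scored first, rescaling `y` by a positive factor or shifting it by a constant leaves `S` unchanged. `S` is symmetric but generally indefinite, which is presumably why `_spoc_mod` sorts eigenvalues by absolute value.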
69 changes: 69 additions & 0 deletions mne/decoding/_mod_ged.py
@@ -0,0 +1,69 @@
"""Eigenvalue eigenvector modifiers for GED transformers."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np


def _compute_mutual_info(covs, sample_weights, evecs):
    class_probas = sample_weights / sample_weights.sum()

    mutual_info = []
    for jj in range(evecs.shape[1]):
        aa, bb = 0, 0
        for cov, prob in zip(covs, class_probas):
            tmp = np.dot(np.dot(evecs[:, jj].T, cov), evecs[:, jj])
            aa += prob * np.log(np.sqrt(tmp))
            bb += prob * (tmp**2 - 1)
        mi = -(aa + (3.0 / 16) * (bb**2))
        mutual_info.append(mi)

    return mutual_info


def _csp_mod(evals, evecs, covs, evecs_order, sample_weights):
    n_classes = sample_weights.shape[0]
    if evecs_order == "mutual_info" and n_classes > 2:
        mutual_info = _compute_mutual_info(covs, sample_weights, evecs)
        ix = np.argsort(mutual_info)[::-1]
    elif evecs_order == "mutual_info" and n_classes == 2:
        ix = np.argsort(np.abs(evals - 0.5))[::-1]
    elif evecs_order == "alternate" and n_classes == 2:
        i = np.argsort(evals)
        ix = np.empty_like(i)
        ix[1::2] = i[: len(i) // 2]
        ix[0::2] = i[len(i) // 2 :][::-1]
    if evals is not None:
        evals = evals[ix]
    evecs = evecs[:, ix]
    return evals, evecs


def _xdawn_mod(evals, evecs, covs=None):
    evals, evecs = _sort_descending(evals, evecs)
    evecs /= np.linalg.norm(evecs, axis=0)
    return evals, evecs


def _ssd_mod(evals, evecs, covs=None):
    evals, evecs = _sort_descending(evals, evecs)
    return evals, evecs


def _spoc_mod(evals, evecs, covs=None):
    evals = evals.real
    evecs = evecs.real
    evals, evecs = _sort_descending(evals, evecs, by_abs=True)
    return evals, evecs


def _sort_descending(evals, evecs, by_abs=False):
    if by_abs:
        ix = np.argsort(np.abs(evals))[::-1]
    else:
        ix = np.argsort(evals)[::-1]
    evals = evals[ix]
    evecs = evecs[:, ix]
    return evals, evecs
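
Review note on `_csp_mod`: the two-class ordering branches are easier to follow on a small example. The sketch below is an illustration only. `order_two_class` is a hypothetical helper, not part of this PR, and it assumes eigenvalues between 0 and 1, which the `evals - 0.5` comparison in the diff suggests for the two-class case.

```python
import numpy as np


def order_two_class(evals, evecs_order):
    """Hypothetical helper: component order for the two-class CSP branches."""
    evals = np.asarray(evals, dtype=float)
    if evecs_order == "mutual_info":
        # Components whose eigenvalue is farthest from 0.5 come first.
        return np.argsort(np.abs(evals - 0.5))[::-1]
    if evecs_order == "alternate":
        # Interleave: largest, smallest, second largest, second smallest, ...
        i = np.argsort(evals)
        ix = np.empty_like(i)
        ix[1::2] = i[: len(i) // 2]
        ix[0::2] = i[len(i) // 2 :][::-1]
        return ix
    raise ValueError(f"Unknown evecs_order: {evecs_order!r}")


evals = np.array([0.05, 0.3, 0.5, 0.8, 0.9])
alternate = evals[order_two_class(evals, "alternate")]
by_distance = evals[order_two_class(evals, "mutual_info")]
print(alternate.tolist())    # [0.9, 0.05, 0.8, 0.3, 0.5]
print(by_distance.tolist())  # [0.05, 0.9, 0.8, 0.3, 0.5]
```

Unlike the function in the diff, the sketch raises on an unrecognized `evecs_order`. In `_csp_mod` as shown, a combination that matches no branch (for example `"alternate"` with more than two classes) would leave `ix` undefined, so that function presumably relies on validation done by its caller.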