-
Notifications
You must be signed in to change notification settings - Fork 1.4k
ENH: Add GED transformer #13259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
ENH: Add GED transformer #13259
Changes from 17 commits
Commits
Show all changes
84 commits
Select commit
Hold shift + click to select a range
632c819
assert_allclose for base ged for csp, spoc, ssd and xdawn
Genuster 4f8b5fa
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 7c072d1
update _epoch_cov logging following merge
Genuster 211d23f
add a few preliminary docstrings
Genuster a701e42
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 0d58c8d
bump rtol/atol for spoc
Genuster 2a1c5cb
Add big sklearn compliance test
Genuster 3e9c32c
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 6e8b3aa
add __sklearn_tags__ to vulture's whitelist
Genuster b2e24ea
calm vulture down per attribute
Genuster fbd585e
put the TransformerMixin back
Genuster a9d5390
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster d142bd0
fix validation of covariances
Genuster 9329494
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 6796366
add gedtranformer tests with audvis dataset
Genuster 7a291b1
fixes following Eric's comments
Genuster 9b34bd3
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 7c867ec
document shapes
Genuster e1e8d6d
another small test for GEDtransformer
Genuster 5edc6fa
change name of restricting map to restricting matrix
Genuster 89fb141
a few more ged tests
Genuster 3986c99
fix multiplication order in original SSD
Genuster 11b038f
add assert_allclose to xdawn and csp transform methods.
Genuster 25e1ae3
more ged tests
Genuster 6bbc459
clean up _xdawn_estimate
Genuster 99d297e
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 029691b
add _validate_params for _XdawnTransformer
Genuster f38ce7d
review suggestions
Genuster 11c31f7
address Eric's suggestions
Genuster 85fb50f
add default no op for mod_ged_callable
Genuster 3c7df08
replace mod_params with partial as well
Genuster 9c7c711
add ged entry in the implementation details
Genuster 95544c5
add feature to perform GED in the principal subspace for xdawn
Genuster 8755089
add option for CSP to select restr_type and provide info
Genuster 87a2466
add restr_type for SCoP and SSD
Genuster 969a73e
fix SSD's filters_ shape inconsistency
Genuster 5266372
use mne's pinv in SSD and Xdawn instead of np.linalg.pinv
Genuster 726c500
move mne.preprocessing._XdawnTransformer to decoding and make it public
Genuster 8e8bf3f
fix docstring
Genuster a374546
fix some terminological imprecisions in the implementation details
Genuster 226abf4
add parameter validation for gedtransformer
Genuster 3da3266
slightly improve validation in csp and ssd
Genuster 4f5d436
rename xdawntranformer's method_params to cov_method_params for consi…
Genuster 5e12465
add picks test for ssd
Genuster 2bfc931
make ssd store ordered filters instead of sorting in transform
Genuster 3a0dd1a
add expected failure for sklearn compliance test
Genuster 19d01bd
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 29da3ff
better solution for the previous fix
Genuster 8d1e656
add temporary xfail for windows pip CIs
Genuster ffac4fd
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 3761ff4
another try
Genuster 8903664
and another
Genuster 60d7360
add sorter return for _mod_ged functions
Genuster f8b8d6b
(1) clean up csp and remove asserts
Genuster 10224fc
(2) clean up spoc and remove asserts
Genuster ac077f7
(3) clean up xdawn, remove asserts and make it store all filters and …
Genuster e23448c
(4) clean up ssd and remove asserts
Genuster 5c612cb
more tests
Genuster 17fae36
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 3c782b3
replace ssd's old whitening with compute_whitener
Genuster 7522910
make XdawnTransformer properly public
Genuster a57934b
add changelog entry
Genuster 7951a43
fix xdawntransformer docstring
Genuster b3a378d
more docstring adventures
Genuster 7b9ad11
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 268f54f
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 10afd78
fix docdict order
Genuster 5049e8c
Merge remote-tracking branch 'upstream/main' into base-GED
Genuster 9ec6835
make all init arguments have default in ged transformer
Genuster 99f6e9c
Merge branch 'main' into base-GED
larsoner 9364b04
Merge branch 'main' into base-GED
larsoner 2b690ea
temporarily skip the problematic test
Genuster 0117163
unskip the test
Genuster b08cf81
fix unskip
Genuster e4e38c3
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner 3fc96c0
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner 4c1b8ff
WIP: Test [actions ssh] [skip circle] [skip azp]
larsoner 0bd74a5
WIP: Test more [actions ssh] [skip circle] [skip azp]
larsoner 67f498d
FIX: More [ci skip]
larsoner f2451ba
WIP: Tests
larsoner 23c37cb
FIX: States
larsoner 13821c7
Merge branch 'main' into base-GED
larsoner 35f5749
update versions and add stars
Genuster d713bfe
Merge branch 'base-GED' of https://github.com/Genuster/mne-python int…
Genuster File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
"""Covariance estimation for GED transformers.""" | ||
|
||
# Authors: The MNE-Python contributors. | ||
# License: BSD-3-Clause | ||
# Copyright the MNE-Python contributors. | ||
|
||
import numpy as np | ||
|
||
from .._fiff.meas_info import Info, create_info | ||
from .._fiff.pick import _picks_to_idx | ||
from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance | ||
from ..defaults import _handle_default | ||
from ..filter import filter_data | ||
from ..rank import compute_rank | ||
from ..utils import _verbose_safe_false, logger | ||
|
||
|
||
def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info): | ||
"""Concatenate epochs before computing the covariance.""" | ||
_, n_channels, _ = x_class.shape | ||
|
||
x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1) | ||
cov = _regularized_covariance( | ||
x_class, | ||
reg=reg, | ||
method_params=cov_method_params, | ||
rank=rank, | ||
info=info, | ||
cov_kind=cov_kind, | ||
log_rank=log_rank, | ||
log_ch_type="data", | ||
) | ||
weight = x_class.shape[0] | ||
|
||
return cov, weight | ||
|
||
|
||
def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info): | ||
"""Mean of per-epoch covariances.""" | ||
name = reg if isinstance(reg, str) else "empirical" | ||
name += " with shrinkage" if isinstance(reg, float) else "" | ||
logger.info( | ||
f"Estimating {cov_kind + (' ' if cov_kind else '')}" | ||
f"covariance (average over epochs; {name.upper()})" | ||
) | ||
cov = sum( | ||
_regularized_covariance( | ||
this_X, | ||
reg=reg, | ||
method_params=cov_method_params, | ||
rank=rank, | ||
info=info, | ||
cov_kind=cov_kind, | ||
log_rank=log_rank and ii == 0, | ||
log_ch_type="data", | ||
verbose=_verbose_safe_false(), | ||
) | ||
for ii, this_X in enumerate(x_class) | ||
) | ||
cov /= len(x_class) | ||
weight = len(x_class) | ||
|
||
return cov, weight | ||
|
||
|
||
def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace): | ||
_, n_channels, _ = X.shape | ||
classes_ = np.unique(y) | ||
if cov_est == "concat": | ||
cov_estimator = _concat_cov | ||
elif cov_est == "epoch": | ||
cov_estimator = _epoch_cov | ||
# Someday we could allow the user to pass this, then we wouldn't need to convert | ||
# but in the meantime they can use a pipeline with a scaler | ||
_info = create_info(n_channels, 1000.0, "mag") | ||
if isinstance(rank, dict): | ||
_rank = {"mag": sum(rank.values())} | ||
else: | ||
_rank = _compute_rank_raw_array( | ||
X.transpose(1, 0, 2).reshape(X.shape[1], -1), | ||
_info, | ||
rank=rank, | ||
scalings=None, | ||
log_ch_type="data", | ||
) | ||
|
||
covs = [] | ||
sample_weights = [] | ||
for ci, this_class in enumerate(classes_): | ||
cov, weight = cov_estimator( | ||
X[y == this_class], | ||
cov_kind=f"class={this_class}", | ||
log_rank=ci == 0, | ||
reg=reg, | ||
cov_method_params=cov_method_params, | ||
rank=_rank, | ||
info=_info, | ||
) | ||
|
||
if norm_trace: | ||
cov /= np.trace(cov) | ||
|
||
covs.append(cov) | ||
sample_weights.append(weight) | ||
|
||
covs = np.stack(covs) | ||
C_ref = covs.mean(0) | ||
|
||
return covs, C_ref, _info, _rank, dict(sample_weights=np.array(sample_weights)) | ||
|
||
|
||
def _xdawn_estimate( | ||
X, | ||
y, | ||
reg, | ||
cov_method_params, | ||
R=None, | ||
events=None, | ||
tmin=0, | ||
sfreq=1, | ||
info=None, | ||
rank="full", | ||
): | ||
from ..preprocessing.xdawn import _least_square_evoked | ||
|
||
if not isinstance(X, np.ndarray) or X.ndim != 3: | ||
raise ValueError("X must be 3D ndarray") | ||
|
||
classes = np.unique(y) | ||
|
||
# XXX Eventually this could be made to deal with rank deficiency properly | ||
# by exposing this "rank" parameter, but this will require refactoring | ||
# the linalg.eigh call to operate in the lower-dimension | ||
# subspace, then project back out. | ||
|
||
# Retrieve or compute whitening covariance | ||
if R is None: | ||
R = _regularized_covariance( | ||
np.hstack(X), reg, cov_method_params, info, rank=rank | ||
) | ||
elif isinstance(R, Covariance): | ||
R = R.data | ||
if not isinstance(R, np.ndarray) or ( | ||
not np.array_equal(R.shape, np.tile(X.shape[1], 2)) | ||
): | ||
raise ValueError( | ||
"R must be None, a covariance instance, " | ||
"or an array of shape (n_chans, n_chans)" | ||
) | ||
|
||
# Get prototype events | ||
if events is not None: | ||
evokeds, toeplitzs = _least_square_evoked(X, events, tmin, sfreq) | ||
else: | ||
evokeds, toeplitzs = list(), list() | ||
for c in classes: | ||
# Prototyped response for each class | ||
evokeds.append(np.mean(X[y == c, :, :], axis=0)) | ||
toeplitzs.append(1.0) | ||
|
||
covs = [] | ||
for evo, toeplitz in zip(evokeds, toeplitzs): | ||
# Estimate covariance matrix of the prototype response | ||
evo = np.dot(evo, toeplitz) | ||
evo_cov = _regularized_covariance(evo, reg, cov_method_params, info, rank=rank) | ||
covs.append(evo_cov) | ||
|
||
covs.append(R) | ||
C_ref = None | ||
rank = None | ||
info = None | ||
return covs, C_ref, info, rank, dict() | ||
|
||
|
||
def _ssd_estimate( | ||
X, | ||
y, | ||
reg, | ||
cov_method_params, | ||
info, | ||
picks, | ||
filt_params_signal, | ||
filt_params_noise, | ||
rank, | ||
): | ||
if isinstance(info, Info): | ||
sfreq = info["sfreq"] | ||
elif isinstance(info, float): # special case, mostly for testing | ||
sfreq = info | ||
info = create_info(X.shape[-2], sfreq, ch_types="eeg") | ||
picks = _picks_to_idx(info, picks, none="data", exclude="bads") | ||
X_aux = X[..., picks, :] | ||
X_signal = filter_data(X_aux, sfreq, **filt_params_signal) | ||
X_noise = filter_data(X_aux, sfreq, **filt_params_noise) | ||
X_noise -= X_signal | ||
if X.ndim == 3: | ||
X_signal = np.hstack(X_signal) | ||
X_noise = np.hstack(X_noise) | ||
|
||
# prevent rank change when computing cov with rank='full' | ||
S = _regularized_covariance( | ||
X_signal, | ||
reg=reg, | ||
method_params=cov_method_params, | ||
rank="full", | ||
info=info, | ||
) | ||
R = _regularized_covariance( | ||
X_noise, | ||
reg=reg, | ||
method_params=cov_method_params, | ||
rank="full", | ||
info=info, | ||
) | ||
covs = [S, R] | ||
C_ref = S | ||
|
||
all_ranks = list() | ||
for cov in covs: | ||
r = list( | ||
compute_rank( | ||
Covariance( | ||
cov, | ||
info.ch_names, | ||
list(), | ||
list(), | ||
0, | ||
verbose=_verbose_safe_false(), | ||
), | ||
rank, | ||
_handle_default("scalings_cov_rank", None), | ||
info, | ||
).values() | ||
)[0] | ||
all_ranks.append(r) | ||
rank = np.min(all_ranks) | ||
return covs, C_ref, info, rank, dict() | ||
|
||
|
||
def _spoc_estimate(X, y, reg, cov_method_params, rank): | ||
# Normalize target variable | ||
target = y.astype(np.float64) | ||
target -= target.mean() | ||
target /= target.std() | ||
|
||
n_epochs, n_channels = X.shape[:2] | ||
|
||
# Estimate single trial covariance | ||
covs = np.empty((n_epochs, n_channels, n_channels)) | ||
for ii, epoch in enumerate(X): | ||
covs[ii] = _regularized_covariance( | ||
epoch, | ||
reg=reg, | ||
method_params=cov_method_params, | ||
rank=rank, | ||
log_ch_type="data", | ||
log_rank=ii == 0, | ||
) | ||
|
||
S = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0) | ||
R = covs.mean(0) | ||
|
||
covs = [S, R] | ||
C_ref = None | ||
info = None | ||
return covs, C_ref, info, rank, dict() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Eigenvalue eigenvector modifiers for GED transformers.""" | ||
|
||
# Authors: The MNE-Python contributors. | ||
# License: BSD-3-Clause | ||
# Copyright the MNE-Python contributors. | ||
|
||
import numpy as np | ||
|
||
|
||
def _compute_mutual_info(covs, sample_weights, evecs): | ||
class_probas = sample_weights / sample_weights.sum() | ||
|
||
mutual_info = [] | ||
for jj in range(evecs.shape[1]): | ||
aa, bb = 0, 0 | ||
for cov, prob in zip(covs, class_probas): | ||
tmp = np.dot(np.dot(evecs[:, jj].T, cov), evecs[:, jj]) | ||
aa += prob * np.log(np.sqrt(tmp)) | ||
bb += prob * (tmp**2 - 1) | ||
mi = -(aa + (3.0 / 16) * (bb**2)) | ||
mutual_info.append(mi) | ||
|
||
return mutual_info | ||
|
||
|
||
def _csp_mod(evals, evecs, covs, evecs_order, sample_weights): | ||
n_classes = sample_weights.shape[0] | ||
if evecs_order == "mutual_info" and n_classes > 2: | ||
mutual_info = _compute_mutual_info(covs, sample_weights, evecs) | ||
ix = np.argsort(mutual_info)[::-1] | ||
elif evecs_order == "mutual_info" and n_classes == 2: | ||
ix = np.argsort(np.abs(evals - 0.5))[::-1] | ||
elif evecs_order == "alternate" and n_classes == 2: | ||
i = np.argsort(evals) | ||
ix = np.empty_like(i) | ||
ix[1::2] = i[: len(i) // 2] | ||
ix[0::2] = i[len(i) // 2 :][::-1] | ||
if evals is not None: | ||
evals = evals[ix] | ||
evecs = evecs[:, ix] | ||
return evals, evecs | ||
|
||
|
||
def _xdawn_mod(evals, evecs, covs=None): | ||
evals, evecs = _sort_descending(evals, evecs) | ||
evecs /= np.linalg.norm(evecs, axis=0) | ||
return evals, evecs | ||
|
||
|
||
def _ssd_mod(evals, evecs, covs=None): | ||
evals, evecs = _sort_descending(evals, evecs) | ||
return evals, evecs | ||
|
||
|
||
def _spoc_mod(evals, evecs, covs=None): | ||
evals = evals.real | ||
evecs = evecs.real | ||
evals, evecs = _sort_descending(evals, evecs, by_abs=True) | ||
return evals, evecs | ||
|
||
|
||
def _sort_descending(evals, evecs, by_abs=False): | ||
if by_abs: | ||
ix = np.argsort(np.abs(evals))[::-1] | ||
else: | ||
ix = np.argsort(evals)[::-1] | ||
evals = evals[ix] | ||
evecs = evecs[:, ix] | ||
return evals, evecs |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.