Skip to content

Commit 7b316cb

Browse files
antoinecollaspre-commit-ci[bot]autofix-ci[bot]larsoner
authored
ENH: add interpolate_to method (mne-tools#13044)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent 64ed255 commit 7b316cb

File tree

6 files changed

+331
-2
lines changed

6 files changed

+331
-2
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add :meth:`mne.Evoked.interpolate_to` to allow interpolating EEG data to other montages, by :newcontrib:`Antoine Collas`.

doc/changes/names.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
.. _Anna Padee: https://github.com/apadee/
2525
.. _Annalisa Pascarella: https://www.iac.cnr.it/personale/annalisa-pascarella
2626
.. _Anne-Sophie Dubarry: https://github.com/annesodub
27+
.. _Antoine Collas: https://www.antoinecollas.fr
2728
.. _Antoine Gauthier: https://github.com/Okamille
2829
.. _Antti Rantala: https://github.com/Odingod
2930
.. _Apoorva Karekal: https://github.com/apoorva6262

doc/references.bib

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2514,3 +2514,11 @@ @article{OyamaEtAl2015
25142514
year = {2015},
25152515
pages = {24--36},
25162516
}
2517+
2518+
@inproceedings{MellotEtAl2024,
2519+
title = {Physics-informed and Unsupervised Riemannian Domain Adaptation for Machine Learning on Heterogeneous EEG Datasets},
2520+
author = {Mellot, Apolline and Collas, Antoine and Chevallier, Sylvain and Engemann, Denis and Gramfort, Alexandre},
2521+
booktitle = {Proceedings of the 32nd European Signal Processing Conference (EUSIPCO)},
2522+
year = {2024},
2523+
address = {Lyon, France}
2524+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
.. _ex-interpolate-to-any-montage:
3+
4+
======================================================
5+
Interpolate EEG data to any montage
6+
======================================================
7+
8+
This example demonstrates how to interpolate EEG channels to match a given montage.
9+
This can be useful for standardizing
10+
EEG channel layouts across different datasets (see :footcite:`MellotEtAl2024`).
11+
12+
- Using the field interpolation for EEG data.
13+
- Using the target montage "biosemi16".
14+
15+
In this example, the data from the original EEG channels will be
16+
interpolated onto the positions defined by the "biosemi16" montage.
17+
"""
18+
19+
# Authors: Antoine Collas <contact@antoinecollas.fr>
20+
# License: BSD-3-Clause
21+
# Copyright the MNE-Python contributors.
22+
23+
import matplotlib.pyplot as plt
24+
25+
import mne
26+
from mne.channels import make_standard_montage
27+
from mne.datasets import sample
28+
29+
print(__doc__)
30+
ylim = (-10, 10)
31+
32+
# %%
33+
# Load EEG data
34+
data_path = sample.data_path()
35+
eeg_file_path = data_path / "MEG" / "sample" / "sample_audvis-ave.fif"
36+
evoked = mne.read_evokeds(eeg_file_path, condition="Left Auditory", baseline=(None, 0))
37+
38+
# Select only EEG channels
39+
evoked.pick("eeg")
40+
41+
# Plot the original EEG layout
42+
evoked.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))
43+
44+
# %%
45+
# Define the target montage
46+
standard_montage = make_standard_montage("biosemi16")
47+
48+
# %%
49+
# Use interpolate_to to project EEG data to the standard montage
50+
evoked_interpolated_spline = evoked.copy().interpolate_to(
51+
standard_montage, method="spline"
52+
)
53+
54+
# Plot the interpolated EEG layout
55+
evoked_interpolated_spline.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))
56+
57+
# %%
58+
# Use interpolate_to to project EEG data to the standard montage
59+
evoked_interpolated_mne = evoked.copy().interpolate_to(standard_montage, method="MNE")
60+
61+
# Plot the interpolated EEG layout
62+
evoked_interpolated_mne.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))
63+
64+
# %%
65+
# Comparing before and after interpolation
66+
fig, axs = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True)
67+
evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False, ylim=dict(eeg=ylim))
68+
axs[0].set_title("Original EEG Layout")
69+
evoked_interpolated_spline.plot(
70+
exclude=[], picks="eeg", axes=axs[1], show=False, ylim=dict(eeg=ylim)
71+
)
72+
axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation")
73+
evoked_interpolated_mne.plot(
74+
exclude=[], picks="eeg", axes=axs[2], show=False, ylim=dict(eeg=ylim)
75+
)
76+
axs[2].set_title("Interpolated to Standard 1020 Montage using MNE interpolation")
77+
78+
# %%
79+
# References
80+
# ----------
81+
# .. footbibliography::

mne/channels/channels.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
pick_info,
4242
pick_types,
4343
)
44-
from .._fiff.proj import setup_proj
44+
from .._fiff.proj import _has_eeg_average_ref_proj, setup_proj
4545
from .._fiff.reference import add_reference_channels, set_eeg_reference
4646
from .._fiff.tag import _rename_list
4747
from ..bem import _check_origin
@@ -960,6 +960,162 @@ def interpolate_bads(
960960

961961
return self
962962

963+
def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0):
964+
"""Interpolate EEG data onto a new montage.
965+
966+
.. warning::
967+
Be careful, only EEG channels are interpolated. Other channel types are
968+
not interpolated.
969+
970+
Parameters
971+
----------
972+
sensors : DigMontage
973+
The target montage containing channel positions to interpolate onto.
974+
origin : array-like, shape (3,) | str
975+
Origin of the sphere in the head coordinate frame and in meters.
976+
Can be ``'auto'`` (default), which means a head-digitization-based
977+
origin fit.
978+
method : str
979+
Method to use for EEG channels.
980+
Supported methods are 'spline' (default) and 'MNE'.
981+
reg : float
982+
The regularization parameter for the interpolation method
983+
(only used when the method is 'spline').
984+
985+
Returns
986+
-------
987+
inst : instance of Raw, Epochs, or Evoked
988+
The instance with updated channel locations and data.
989+
990+
Notes
991+
-----
992+
This method is useful for standardizing EEG layouts across datasets.
993+
However, some attributes may be lost after interpolation.
994+
995+
.. versionadded:: 1.10.0
996+
"""
997+
from ..epochs import BaseEpochs, EpochsArray
998+
from ..evoked import Evoked, EvokedArray
999+
from ..forward._field_interpolation import _map_meg_or_eeg_channels
1000+
from ..io import RawArray
1001+
from ..io.base import BaseRaw
1002+
from .interpolation import _make_interpolation_matrix
1003+
from .montage import DigMontage
1004+
1005+
# Check that the method option is valid.
1006+
_check_option("method", method, ["spline", "MNE"])
1007+
_validate_type(sensors, DigMontage, "sensors")
1008+
1009+
# Get target positions from the montage
1010+
ch_pos = sensors.get_positions().get("ch_pos", {})
1011+
target_ch_names = list(ch_pos.keys())
1012+
if not target_ch_names:
1013+
raise ValueError(
1014+
"The provided sensors configuration has no channel positions."
1015+
)
1016+
1017+
# Get original channel order
1018+
orig_names = self.info["ch_names"]
1019+
1020+
# Identify EEG channel
1021+
picks_good_eeg = pick_types(self.info, meg=False, eeg=True, exclude="bads")
1022+
if len(picks_good_eeg) == 0:
1023+
raise ValueError("No good EEG channels available for interpolation.")
1024+
# Also get the full list of EEG channel indices (including bad channels)
1025+
picks_remove_eeg = pick_types(self.info, meg=False, eeg=True, exclude=[])
1026+
eeg_names_orig = [orig_names[i] for i in picks_remove_eeg]
1027+
1028+
# Identify non-EEG channels in original order
1029+
non_eeg_names_ordered = [ch for ch in orig_names if ch not in eeg_names_orig]
1030+
1031+
# Create destination info for new EEG channels
1032+
sfreq = self.info["sfreq"]
1033+
info_interp = create_info(
1034+
ch_names=target_ch_names,
1035+
sfreq=sfreq,
1036+
ch_types=["eeg"] * len(target_ch_names),
1037+
)
1038+
info_interp.set_montage(sensors)
1039+
info_interp["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names]
1040+
# Do not assign "projs" directly.
1041+
1042+
# Compute the interpolation mapping
1043+
if method == "spline":
1044+
origin_val = _check_origin(origin, self.info)
1045+
pos_from = self.info._get_channel_positions(picks_good_eeg) - origin_val
1046+
pos_to = np.stack(list(ch_pos.values()), axis=0)
1047+
1048+
def _check_pos_sphere(pos):
1049+
d = np.linalg.norm(pos, axis=-1)
1050+
d_norm = np.mean(d / np.mean(d))
1051+
if np.abs(1.0 - d_norm) > 0.1:
1052+
warn("Your spherical fit is poor; interpolation may be inaccurate.")
1053+
1054+
_check_pos_sphere(pos_from)
1055+
_check_pos_sphere(pos_to)
1056+
mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg)
1057+
1058+
else:
1059+
assert method == "MNE"
1060+
info_eeg = pick_info(self.info, picks_good_eeg)
1061+
# If the original info has an average EEG reference projector but
1062+
# the destination info does not,
1063+
# update info_interp via a temporary RawArray.
1064+
if _has_eeg_average_ref_proj(self.info) and not _has_eeg_average_ref_proj(
1065+
info_interp
1066+
):
1067+
# Create dummy data: shape (n_channels, 1)
1068+
temp_data = np.zeros((len(info_interp["ch_names"]), 1))
1069+
temp_raw = RawArray(temp_data, info_interp, first_samp=0)
1070+
# Using the public API, add an average reference projector.
1071+
temp_raw.set_eeg_reference(
1072+
ref_channels="average", projection=True, verbose=False
1073+
)
1074+
# Extract the updated info.
1075+
info_interp = temp_raw.info
1076+
mapping = _map_meg_or_eeg_channels(
1077+
info_eeg, info_interp, mode="accurate", origin=origin
1078+
)
1079+
1080+
# Interpolate EEG data
1081+
data_good = self.get_data(picks=picks_good_eeg)
1082+
data_interp = mapping @ data_good
1083+
1084+
# Create a new instance for the interpolated EEG channels
1085+
# TODO: Creating a new instance leads to a loss of information.
1086+
# We should consider updating the existing instance in the future
1087+
# by 1) drop channels, 2) add channels, 3) re-order channels.
1088+
if isinstance(self, BaseRaw):
1089+
inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp)
1090+
elif isinstance(self, BaseEpochs):
1091+
inst_interp = EpochsArray(data_interp, info_interp)
1092+
else:
1093+
assert isinstance(self, Evoked)
1094+
inst_interp = EvokedArray(data_interp, info_interp)
1095+
1096+
# Merge only if non-EEG channels exist
1097+
if not non_eeg_names_ordered:
1098+
return inst_interp
1099+
1100+
inst_non_eeg = self.copy().pick(non_eeg_names_ordered).load_data()
1101+
inst_out = inst_non_eeg.add_channels([inst_interp], force_update_info=True)
1102+
1103+
# Reorder channels
1104+
# Insert the entire new EEG block at the position of the first EEG channel.
1105+
orig_names_arr = np.array(orig_names)
1106+
mask_eeg = np.isin(orig_names_arr, eeg_names_orig)
1107+
if mask_eeg.any():
1108+
first_eeg_index = np.where(mask_eeg)[0][0]
1109+
pre = orig_names_arr[:first_eeg_index]
1110+
new_eeg = np.array(info_interp["ch_names"])
1111+
post = orig_names_arr[first_eeg_index:]
1112+
post = post[~np.isin(orig_names_arr[first_eeg_index:], eeg_names_orig)]
1113+
new_order = np.concatenate((pre, new_eeg, post)).tolist()
1114+
else:
1115+
new_order = orig_names
1116+
inst_out.reorder_channels(new_order)
1117+
return inst_out
1118+
9631119

9641120
@verbose
9651121
def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None):

mne/channels/tests/test_interpolation.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mne import Epochs, pick_channels, pick_types, read_events
1313
from mne._fiff.constants import FIFF
1414
from mne._fiff.proj import _has_eeg_average_ref_proj
15-
from mne.channels import make_dig_montage
15+
from mne.channels import make_dig_montage, make_standard_montage
1616
from mne.channels.interpolation import _make_interpolation_matrix
1717
from mne.datasets import testing
1818
from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx
@@ -439,3 +439,85 @@ def test_method_str():
439439
raw.interpolate_bads(method="spline")
440440
raw.pick("eeg", exclude=())
441441
raw.interpolate_bads(method="spline")
442+
443+
444+
@pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"])
445+
@pytest.mark.parametrize("method", ["spline", "MNE"])
446+
@pytest.mark.parametrize("data_type", ["raw", "epochs", "evoked"])
447+
def test_interpolate_to_eeg(montage_name, method, data_type):
448+
"""Test the interpolate_to method for EEG for raw, epochs, and evoked."""
449+
# Load EEG data
450+
raw, epochs_eeg = _load_data("eeg")
451+
epochs_eeg = epochs_eeg.copy()
452+
453+
# Load data for raw
454+
raw.load_data()
455+
456+
# Create a target montage
457+
montage = make_standard_montage(montage_name)
458+
459+
# Prepare data to interpolate to
460+
if data_type == "raw":
461+
inst = raw.copy()
462+
elif data_type == "epochs":
463+
inst = epochs_eeg.copy()
464+
elif data_type == "evoked":
465+
inst = epochs_eeg.average()
466+
shape = list(inst._data.shape)
467+
orig_total = len(inst.info["ch_names"])
468+
n_eeg_orig = len(pick_types(inst.info, eeg=True))
469+
470+
# Assert first and last channels are not EEG
471+
if data_type == "raw":
472+
ch_types = inst.get_channel_types()
473+
assert ch_types[0] != "eeg"
474+
assert ch_types[-1] != "eeg"
475+
476+
# Record the names and data of the first and last channels.
477+
if data_type == "raw":
478+
first_name = inst.info["ch_names"][0]
479+
last_name = inst.info["ch_names"][-1]
480+
data_first = inst._data[..., 0, :].copy()
481+
data_last = inst._data[..., -1, :].copy()
482+
483+
# Interpolate the EEG channels.
484+
inst_interp = inst.copy().interpolate_to(montage, method=method)
485+
486+
# Check that the new channel names include the montage channels.
487+
assert set(montage.ch_names).issubset(set(inst_interp.info["ch_names"]))
488+
# Check that the overall channel order is changed.
489+
assert inst.info["ch_names"] != inst_interp.info["ch_names"]
490+
491+
# Check that the data shape is as expected.
492+
new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names)
493+
expected_shape = (new_nchan_expected, shape[-1])
494+
if len(shape) == 3:
495+
expected_shape = (shape[0],) + expected_shape
496+
assert inst_interp._data.shape == expected_shape
497+
498+
# Verify that the first and last channels retain their positions.
499+
if data_type == "raw":
500+
assert inst_interp.info["ch_names"][0] == first_name
501+
assert inst_interp.info["ch_names"][-1] == last_name
502+
503+
# Verify that the data for the first and last channels is unchanged.
504+
if data_type == "raw":
505+
np.testing.assert_allclose(
506+
inst_interp._data[..., 0, :],
507+
data_first,
508+
err_msg="Data for the first non-EEG channel has changed.",
509+
)
510+
np.testing.assert_allclose(
511+
inst_interp._data[..., -1, :],
512+
data_last,
513+
err_msg="Data for the last non-EEG channel has changed.",
514+
)
515+
516+
# Validate that bad channels are carried over.
517+
# Mark the first non eeg channel as bad
518+
all_ch = inst_interp.info["ch_names"]
519+
eeg_ch = [all_ch[i] for i in pick_types(inst_interp.info, eeg=True)]
520+
bads = [ch for ch in all_ch if ch not in eeg_ch][:1]
521+
inst.info["bads"] = bads
522+
inst_interp = inst.copy().interpolate_to(montage, method=method)
523+
assert inst_interp.info["bads"] == bads

0 commit comments

Comments
 (0)