Skip to content

Commit 9f31cf3

Browse files
Adding nan method to interpolate channels (mne-tools#12027)
Co-authored-by: Daniel McCloy <dan@mccloy.info>
1 parent 04e05d4 commit 9f31cf3

File tree

3 files changed

+86
-12
lines changed

3 files changed

+86
-12
lines changed

mne/channels/channels.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -830,17 +830,19 @@ def interpolate_bads(
830830
.. versionadded:: 0.17
831831
method : dict | None
832832
Method to use for each channel type.
833-
Currently only the key ``"eeg"`` has multiple options:
833+
All channel types support "nan".
834+
The key ``"eeg"`` has two additional options:
834835
835836
- ``"spline"`` (default)
836837
Use spherical spline interpolation.
837838
- ``"MNE"``
838839
Use minimum-norm projection to a sphere and back.
839840
This is the method used for MEG channels.
840841
841-
The value for ``"meg"`` is ``"MNE"``, and the value for
842-
``"fnirs"`` is ``"nearest"``. The default (None) is thus an alias
843-
for::
842+
The default value for ``"meg"`` is ``"MNE"``, and the default value
843+
for ``"fnirs"`` is ``"nearest"``.
844+
845+
The default (None) is thus an alias for::
844846
845847
method=dict(meg="MNE", eeg="spline", fnirs="nearest")
846848
@@ -858,6 +860,10 @@ def interpolate_bads(
858860
Notes
859861
-----
860862
.. versionadded:: 0.9.0
863+
864+
.. warning::
865+
Be careful when using ``method="nan"``; the default value
866+
``reset_bads=True`` may not be what you want.
861867
"""
862868
from .interpolation import (
863869
_interpolate_bads_eeg,
@@ -869,9 +875,31 @@ def interpolate_bads(
869875
method = _handle_default("interpolation_method", method)
870876
for key in method:
871877
_check_option("method[key]", key, ("meg", "eeg", "fnirs"))
872-
_check_option("method['eeg']", method["eeg"], ("spline", "MNE"))
873-
_check_option("method['meg']", method["meg"], ("MNE",))
874-
_check_option("method['fnirs']", method["fnirs"], ("nearest",))
878+
_check_option(
879+
"method['eeg']",
880+
method["eeg"],
881+
(
882+
"spline",
883+
"MNE",
884+
"nan",
885+
),
886+
)
887+
_check_option(
888+
"method['meg']",
889+
method["meg"],
890+
(
891+
"MNE",
892+
"nan",
893+
),
894+
)
895+
_check_option(
896+
"method['fnirs']",
897+
method["fnirs"],
898+
(
899+
"nearest",
900+
"nan",
901+
),
902+
)
875903

876904
if len(self.info["bads"]) == 0:
877905
warn("No bad channels to interpolate. Doing nothing...")
@@ -884,11 +912,18 @@ def interpolate_bads(
884912
else:
885913
eeg_mne = True
886914
_interpolate_bads_meeg(
887-
self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude
915+
self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude, method=method
888916
)
889-
_interpolate_bads_nirs(self, exclude=exclude)
917+
_interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"])
890918

891919
if reset_bads is True:
920+
if "nan" in method.values():
921+
warn(
922+
"interpolate_bads was called with method='nan' and "
923+
"reset_bads=True. Consider setting reset_bads=False so that the "
924+
"nan-containing channels can be easily excluded from later "
925+
"computations."
926+
)
892927
self.info["bads"] = [ch for ch in self.info["bads"] if ch in exclude]
893928

894929
return self

mne/channels/interpolation.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Authors: Denis Engemann <denis.engemann@gmail.com>
2+
# Ana Radanovic <radanovica@protonmail.com>
23
#
34
# License: BSD-3-Clause
45

@@ -191,10 +192,14 @@ def _interpolate_bads_meeg(
191192
eeg=True,
192193
ref_meg=False,
193194
exclude=(),
195+
*,
196+
method=None,
194197
verbose=None,
195198
):
196199
from ..forward import _map_meg_or_eeg_channels
197200

201+
if method is None:
202+
method = {"meg": "MNE", "eeg": "MNE"}
198203
bools = dict(meg=meg, eeg=eeg)
199204
info = _simplify_info(inst.info)
200205
for ch_type, do in bools.items():
@@ -210,6 +215,12 @@ def _interpolate_bads_meeg(
210215
continue
211216
# select the bad channels to be interpolated
212217
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
218+
219+
if method[ch_type] == "nan":
220+
inst._data[picks_bad] = np.nan
221+
continue
222+
223+
# do MNE based interpolation
213224
if ch_type == "eeg":
214225
picks_to = picks_type
215226
bad_sel = np.isin(picks_type, picks_bad)
@@ -243,7 +254,7 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
243254
chs = [inst.info["chs"][i] for i in picks_nirs]
244255
locs3d = np.array([ch["loc"][:3] for ch in chs])
245256

246-
_check_option("fnirs_method", method, ["nearest"])
257+
_check_option("fnirs_method", method, ["nearest", "nan"])
247258

248259
if method == "nearest":
249260
dist = pdist(locs3d)
@@ -258,7 +269,10 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
258269
# Find closest remaining channels for same frequency
259270
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
260271
inst._data[bad] = inst._data[closest_idx]
261-
262-
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]
272+
else:
273+
assert method == "nan"
274+
inst._data[picks_bad] = np.nan
275+
# TODO: this seems like a bug because it does not respect reset_bads
276+
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]
263277

264278
return inst

mne/channels/tests/test_interpolation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mne._fiff.proj import _has_eeg_average_ref_proj
1818
from mne.utils import _record_warnings
1919

20+
2021
base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
2122
raw_fname = base_dir / "test_raw.fif"
2223
event_name = base_dir / "test-eve.fif"
@@ -324,3 +325,27 @@ def test_interpolation_nirs():
324325
assert raw_haemo.info["bads"] == ["S1_D2 hbo", "S1_D2 hbr"]
325326
raw_haemo.interpolate_bads()
326327
assert raw_haemo.info["bads"] == []
328+
329+
330+
def test_nan_interpolation(raw):
331+
"""Test 'nan' method for interpolating bads."""
332+
ch_to_interp = [raw.ch_names[1]] # don't use channel 0 (type is IAS not MEG)
333+
raw.info["bads"] = ch_to_interp
334+
335+
# test that warning appears for reset_bads = True
336+
with pytest.warns(RuntimeWarning, match="Consider setting reset_bads=False"):
337+
raw.interpolate_bads(method="nan", reset_bads=True)
338+
339+
# despite warning, interpolation still happened, make sure the channel is NaN
340+
bad_chs = raw.get_data(ch_to_interp)
341+
assert np.isnan(bad_chs).all()
342+
343+
# make sure reset_bads=False works as expected
344+
raw.info["bads"] = ch_to_interp
345+
raw.interpolate_bads(method="nan", reset_bads=False)
346+
assert raw.info["bads"] == ch_to_interp
347+
348+
# make sure other channels are untouched
349+
raw.drop_channels(ch_to_interp)
350+
good_chs = raw.get_data()
351+
assert np.isfinite(good_chs).all()

0 commit comments

Comments
 (0)