Skip to content

Commit 9ed326f

Browse files
anaradanovicpre-commit-ci[bot]drammocknordmelarsoner
authored
adding unify channels to preprocessing draft (mne-tools#12014)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy <dan@mccloy.info> Co-authored-by: nordme <nordme@uw.edu> Co-authored-by: nordme <38704848+nordme@users.noreply.github.com> Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent 927f88d commit 9ed326f

File tree

4 files changed

+134
-0
lines changed

4 files changed

+134
-0
lines changed

doc/preprocessing.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ Projections:
5757
get_builtin_ch_adjacencies
5858
read_ch_adjacency
5959
equalize_channels
60+
unify_bad_channels
6061
rename_channels
6162
generate_2d_layout
6263
make_1020_channel_selections

mne/channels/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"_EEG_SELECTIONS",
2323
"_divide_to_regions",
2424
"get_builtin_ch_adjacencies",
25+
"unify_bad_channels",
2526
],
2627
"layout": [
2728
"Layout",

mne/channels/channels.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# Andrew Dykstra <andrew.r.dykstra@gmail.com>
66
# Teon Brooks <teon.brooks@gmail.com>
77
# Daniel McCloy <dan.mccloy@gmail.com>
8+
# Ana Radanovic <radanovica@protonmail.com>
9+
# Erica Peterson <nordme@uw.edu>
810
#
911
# License: BSD-3-Clause
1012

@@ -206,6 +208,83 @@ def equalize_channels(instances, copy=True, verbose=None):
206208
return equalized_instances
207209

208210

211+
def unify_bad_channels(insts):
212+
"""Unify bad channels across a list of instances.
213+
214+
All instances must be of the same type and have matching channel names and channel
215+
order. The ``.info["bads"]`` of each instance will be set to the union of
216+
``.info["bads"]`` across all instances.
217+
218+
Parameters
219+
----------
220+
insts : list
221+
List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`,
222+
:class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`,
223+
:class:`~mne.time_frequency.EpochsSpectrum`) across which to unify bad channels.
224+
225+
Returns
226+
-------
227+
insts : list
228+
List of instances with bad channels unified across instances.
229+
230+
See Also
231+
--------
232+
mne.channels.equalize_channels
233+
mne.channels.rename_channels
234+
mne.channels.combine_channels
235+
236+
Notes
237+
-----
238+
This function modifies the instances in-place.
239+
240+
.. versionadded:: 1.6
241+
"""
242+
from ..io import BaseRaw
243+
from ..epochs import Epochs
244+
from ..evoked import Evoked
245+
from ..time_frequency.spectrum import BaseSpectrum
246+
247+
# ensure input is list-like
248+
_validate_type(insts, (list, tuple), "insts")
249+
# ensure non-empty
250+
if len(insts) == 0:
251+
raise ValueError("insts must not be empty")
252+
# ensure all insts are MNE objects, and all the same type
253+
inst_type = type(insts[0])
254+
valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum)
255+
for inst in insts:
256+
_validate_type(inst, valid_types, "each object in insts")
257+
if type(inst) != inst_type:
258+
raise ValueError("All insts must be the same type")
259+
260+
# ensure all insts have the same channels and channel order
261+
ch_names = insts[0].ch_names
262+
for inst in insts[1:]:
263+
dif = set(inst.ch_names) ^ set(ch_names)
264+
if len(dif):
265+
raise ValueError(
266+
"Channels do not match across the objects in insts. Consider calling "
267+
"equalize_channels before calling this function."
268+
)
269+
elif inst.ch_names != ch_names:
270+
raise ValueError(
271+
"Channel names are sorted differently across instances. Please use "
272+
"mne.channels.equalize_channels."
273+
)
274+
275+
# collect bads as dict keys so that insertion order is preserved, then cast to list
276+
all_bads = dict()
277+
for inst in insts:
278+
all_bads.update(dict.fromkeys(inst.info["bads"]))
279+
all_bads = list(all_bads)
280+
281+
# update bads on all instances
282+
for inst in insts:
283+
inst.info["bads"] = all_bads
284+
285+
return insts
286+
287+
209288
class ReferenceMixin(MontageMixin):
210289
"""Mixin class for Raw, Evoked, Epochs."""
211290

mne/channels/tests/test_unify_bads.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
from mne.channels import unify_bad_channels
3+
4+
5+
def test_error_raising(raw, epochs):
6+
"""Tests input checking."""
7+
with pytest.raises(TypeError, match=r"must be an instance of list"):
8+
unify_bad_channels("bad input")
9+
with pytest.raises(ValueError, match=r"insts must not be empty"):
10+
unify_bad_channels([])
11+
with pytest.raises(TypeError, match=r"each object in insts must be an instance of"):
12+
unify_bad_channels(["bad_instance"])
13+
with pytest.raises(ValueError, match=r"same type"):
14+
unify_bad_channels([raw, epochs])
15+
with pytest.raises(ValueError, match=r"Channels do not match across"):
16+
raw_alt1 = raw.copy()
17+
raw_alt1.drop_channels(raw.info["ch_names"][-1])
18+
unify_bad_channels([raw, raw_alt1]) # ch diff preserving order
19+
with pytest.raises(ValueError, match=r"sorted differently"):
20+
raw_alt2 = raw.copy()
21+
new_order = [raw.ch_names[-1]] + raw.ch_names[:-1]
22+
raw_alt2.reorder_channels(new_order)
23+
unify_bad_channels([raw, raw_alt2])
24+
25+
26+
def test_bads_compilation(raw):
27+
"""Tests that bads are compiled properly.
28+
29+
Tests two cases: a) single instance passed to function with an existing
30+
bad, and b) multiple instances passed to function with varying compilation
31+
scenarios including empty bads, unique bads, and partially duplicated bads
32+
listed out-of-order.
33+
34+
Only the Raw instance type is tested, since bad channel implementation is
35+
controlled across instance types with a MixIn class.
36+
"""
37+
assert raw.info["bads"] == []
38+
chns = raw.ch_names[:3]
39+
no_bad = raw.copy()
40+
one_bad = raw.copy()
41+
one_bad.info["bads"] = [chns[1]]
42+
three_bad = raw.copy()
43+
three_bad.info["bads"] = chns
44+
# scenario 1: single instance passed with actual bads
45+
s_out = unify_bad_channels([one_bad])
46+
assert len(s_out) == 1, len(s_out)
47+
assert s_out[0].info["bads"] == [chns[1]], (s_out[0].info["bads"], chns[1])
48+
# scenario 2: multiple instances passed
49+
m_out = unify_bad_channels([one_bad, no_bad, three_bad])
50+
assert len(m_out) == 3, len(m_out)
51+
expected_order = [chns[1], chns[0], chns[2]]
52+
for inst in m_out:
53+
assert inst.info["bads"] == expected_order

0 commit comments

Comments
 (0)